"""
=============================================================================
DCDE: Depth-Conditioned Dynamic Ensemble with Evidential Uncertainty
for Femtosecond Laser Internal Hydrogel Etching Prediction

A novel hybrid architecture combining:
1. FiLM-conditioned Neural Network (depth-adaptive feature modulation)
2. XGBoost gradient-boosted trees (capturing tabular feature interactions)
3. Learned dynamic gating network (input-conditioned fusion)
4. Evidential Deep Learning (Normal-Inverse-Gamma uncertainty)
5. Physics-informed regularization (monotonicity + energy constraints)

References:
- FiLM: Perez et al., AAAI 2018 (arxiv:1709.07871)
- Deep Evidential Regression: Amini et al., NeurIPS 2020 (arxiv:1910.02600)
- DELE gating: AAAI 2023 (arxiv:2302.00932)
- Physics-informed ML: Zhang et al. 2022 (arxiv:2211.08064)
=============================================================================
"""

from __future__ import annotations

import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# =============================================================================
# 1. PHYSICS-INFORMED FEATURE ENGINEERING (Depth-Dependent)
# =============================================================================

class DepthPhysicsFeatures:
    """
    Compute analytically-derived physics features that encode how
    femtosecond laser behavior changes with focusing depth in hydrogels.
    
    These features capture three primary depth-dependent effects:
    1. Spherical aberration (Strehl ratio degradation)
    2. Group velocity dispersion (pulse temporal broadening)
    3. Self-focusing proximity (Kerr nonlinearity regime)
    
    Scientific basis:
    - Vogel et al., Applied Physics B (2005) - fs-laser tissue interaction
    - Schaffer et al., Optics Letters (2001) - bulk modification thresholds
    - Boyd, Nonlinear Optics (2020) - self-focusing, GVD theory
    """
    
    def __init__(
        self,
        n_medium: float = 1.34,      # Refractive index of hydrogel
        beta2_fs2_mm: float = 55.0,  # GVD parameter (fs²/mm) for water-like medium
        n2_m2_W: float = 2.0e-20,    # Nonlinear refractive index (m²/W)
    ):
        self.n_medium = n_medium
        self.beta2 = beta2_fs2_mm * 1e-30 / 1e-3  # Convert to s²/m
        self.n2 = n2_m2_W
    
    def compute(
        self,
        focusing_depth_um: np.ndarray,
        pulse_duration_fs: np.ndarray,
        wavelength_nm: np.ndarray,
        NA: np.ndarray,
        power_mW: np.ndarray,
        rep_rate_kHz: np.ndarray,
    ) -> np.ndarray:
        """
        Compute physics features from raw parameters.
        
        Returns array of shape (N, 5) with columns:
        [strehl_ratio, intensity_factor, z_normalized, self_focus_ratio, depth_aberration]
        """
        z = np.asarray(focusing_depth_um) * 1e-6  # µm → m
        tau0 = np.asarray(pulse_duration_fs) * 1e-15  # fs → s
        lam = np.asarray(wavelength_nm) * 1e-9  # nm → m
        na = np.asarray(NA)
        P_avg = np.asarray(power_mW) * 1e-3  # mW → W
        f_rep = np.asarray(rep_rate_kHz) * 1e3  # kHz → Hz
        
        # 1. Strehl ratio: S(z) = exp(-(2π·Δn·z·NA²/λ)²)
        # Quantifies how much aberration degrades the focal spot
        delta_n = self.n_medium - 1.0  # Air-hydrogel RI mismatch
        strehl = np.exp(-((2 * np.pi * delta_n * z * na**2) / lam)**2)
        strehl = np.clip(strehl, 1e-6, 1.0)
        
        # 2. GVD pulse broadening: τ(z) = τ₀·√(1 + (z/L_D)²)
        # Reduced peak intensity at depth
        L_D = tau0**2 / np.abs(self.beta2)  # Dispersion length
        tau_z = tau0 * np.sqrt(1 + (z / np.maximum(L_D, 1e-10))**2)
        intensity_factor = tau0 / np.maximum(tau_z, tau0)  # ∈ (0, 1]
        
        # 3. Normalized depth (relative to Rayleigh range)
        # Indicates when geometric vs. wave-optical effects dominate
        w0 = lam / (np.pi * np.maximum(na, 0.01))  # Beam waist
        z_rayleigh = np.pi * w0**2 / lam
        z_normalized = z / np.maximum(z_rayleigh, 1e-10)
        
        # 4. Self-focusing proximity: P_peak / P_critical
        # When > 1: catastrophic self-focusing regime
        P_peak = P_avg / (f_rep * tau0)  # Peak power per pulse
        P_cr = 3.77 * lam**2 / (8 * np.pi * self.n_medium * self.n2)
        sf_ratio = P_peak / np.maximum(P_cr, 1e-10)
        sf_ratio = np.clip(sf_ratio, 0, 50)  # Cap at 50× critical
        
        # 5. Depth-dependent aberration parameter
        # Combined effect: how much the focal volume degrades with depth
        depth_aberration = delta_n * z * na**2 / lam
        
        return np.column_stack([
            strehl,
            intensity_factor,
            z_normalized,
            sf_ratio,
            depth_aberration,
        ]).astype(np.float32)
    
    @property
    def feature_names(self) -> List[str]:
        return [
            "strehl_ratio",
            "intensity_factor_gvd",
            "z_normalized_rayleigh",
            "self_focusing_ratio",
            "depth_aberration_param",
        ]
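

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): a scalar,
# dependency-free walk-through of two formulas used in
# DepthPhysicsFeatures.compute, namely GVD pulse broadening and the critical
# power for self-focusing. The helper name and the default operating point
# (100 fs pulses, 800 nm, 500 µm depth) are assumptions chosen for illustration.
# -----------------------------------------------------------------------------
import math


def _sketch_gvd_and_critical_power(
    depth_um=500.0,
    pulse_fs=100.0,
    wavelength_nm=800.0,
    n_medium=1.34,
    beta2_fs2_mm=55.0,
    n2_m2_W=2.0e-20,
):
    """Return (dispersion length in m, GVD intensity factor, P_cr in W)."""
    z = depth_um * 1e-6                      # µm → m
    tau0 = pulse_fs * 1e-15                  # fs → s
    lam = wavelength_nm * 1e-9               # nm → m
    beta2 = beta2_fs2_mm * 1e-27             # fs²/mm → s²/m
    dispersion_length = tau0 ** 2 / abs(beta2)
    tau_z = tau0 * math.sqrt(1.0 + (z / dispersion_length) ** 2)
    intensity_factor = tau0 / tau_z          # ∈ (0, 1]
    p_critical = 3.77 * lam ** 2 / (8.0 * math.pi * n_medium * n2_m2_W)
    return dispersion_length, intensity_factor, p_critical


# Under these assumed values the dispersion length is roughly 0.18 m, so a
# 100 fs pulse barely broadens over sub-millimetre depths, and the critical
# power comes out at a few megawatts.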


# =============================================================================
# 2. FiLM-CONDITIONED NEURAL NETWORK (Depth-Adaptive)
# =============================================================================

class FiLMGenerator(nn.Module):
    """
    Feature-wise Linear Modulation (FiLM) generator.
    
    Maps conditioning input (depth features) to per-layer (γ, β) pairs
    that modulate hidden representations: h' = γ ⊙ h + β
    
    Uses the Δγ initialization trick: γ = 1 + Δγ for stable training
    (identity modulation at initialization).
    
    Reference: Perez et al., "FiLM: Visual Reasoning with a General 
    Conditioning Layer", AAAI 2018.
    """
    
    def __init__(self, conditioning_dim: int, hidden_dims: List[int]):
        super().__init__()
        self.generators = nn.ModuleList()
        
        for h_dim in hidden_dims:
            self.generators.append(
                nn.Sequential(
                    nn.Linear(conditioning_dim, 64),
                    nn.SiLU(),
                    nn.Linear(64, h_dim * 2),  # γ and β
                )
            )
        
        # Initialize near identity (Δγ ≈ 0, β ≈ 0)
        for gen in self.generators:
            nn.init.zeros_(gen[-1].weight)
            nn.init.zeros_(gen[-1].bias)
    
    def forward(self, conditioning: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Parameters
        ----------
        conditioning : Tensor, shape (B, conditioning_dim)
            Depth-related features for conditioning
        
        Returns
        -------
        list of (gamma, beta) tuples for each layer
        """
        film_params = []
        for gen in self.generators:
            params = gen(conditioning)
            h_dim = params.shape[-1] // 2
            delta_gamma = params[:, :h_dim]
            beta = params[:, h_dim:]
            gamma = 1.0 + delta_gamma  # Δγ trick
            film_params.append((gamma, beta))
        return film_params
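

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): the modulation
# h' = (1 + Δγ) ⊙ h + β written out on plain Python lists, to show why
# zero-initializing the generator's last layer gives an identity map at the
# start of training. The helper name is hypothetical.
# -----------------------------------------------------------------------------
def _sketch_film_modulate(h, delta_gamma, beta):
    """Apply FiLM element-wise with the Δγ parameterization (γ = 1 + Δγ)."""
    return [(1.0 + dg) * x + b for x, dg, b in zip(h, delta_gamma, beta)]


# With Δγ = 0 and β = 0 the activations pass through unchanged; non-zero
# values scale and shift each hidden unit separately.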


class FiLMConditionedMLP(nn.Module):
    """
    Multi-layer perceptron with FiLM conditioning at each hidden layer.
    
    Architecture:
        Input → [Linear → BatchNorm → FiLM(γ,β) → SiLU → Dropout] × L → Output
    
    The FiLM conditioning allows depth information to modulate the network's
    intermediate representations multiplicatively, enabling fundamentally
    different processing depending on focusing depth — not just adding depth
    as another input feature.
    """
    
    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        conditioning_dim: int,
        dropout: float = 0.15,
    ):
        super().__init__()
        self.hidden_dims = hidden_dims
        
        # Build layers
        dims = [input_dim] + hidden_dims
        self.layers = nn.ModuleList([
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        ])
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm1d(d) for d in hidden_dims
        ])
        self.dropouts = nn.ModuleList([
            nn.Dropout(dropout * (1 - i / len(hidden_dims)))
            for i in range(len(hidden_dims))
        ])
        
        # FiLM generator (depth → modulation parameters)
        self.film_generator = FiLMGenerator(conditioning_dim, hidden_dims)
        
        # Output projection
        self.output_layer = nn.Linear(hidden_dims[-1], output_dim)
    
    def forward(
        self,
        x: torch.Tensor,
        conditioning: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor (B, input_dim) - laser + material features
        conditioning : Tensor (B, conditioning_dim) - depth physics features
        
        Returns
        -------
        Tensor (B, output_dim) - latent representation
        """
        # Get FiLM parameters for all layers
        film_params = self.film_generator(conditioning)
        
        h = x
        for i, (layer, bn, dropout) in enumerate(
            zip(self.layers, self.batch_norms, self.dropouts)
        ):
            h = layer(h)
            h = bn(h)
            # Apply FiLM modulation
            gamma, beta = film_params[i]
            h = gamma * h + beta
            h = F.silu(h)
            h = dropout(h)
        
        return self.output_layer(h)


# =============================================================================
# 3. EVIDENTIAL REGRESSION HEAD (Normal-Inverse-Gamma)
# =============================================================================

class EvidentialHead(nn.Module):
    """
    Normal-Inverse-Gamma (NIG) evidential regression head.
    
    Outputs four parameters per target that parameterize a NIG distribution,
    providing both aleatoric and epistemic uncertainty estimates in a single
    forward pass (no ensemble or MC dropout required).
    
    For each output dimension:
        μ ~ N(γ, σ²/ν)        [predictive mean with epistemic noise]
        σ² ~ InvGamma(α, β)   [aleatoric variance]
    
    Uncertainty decomposition:
        Aleatoric:  E[σ²] = β / (α - 1)
        Epistemic:  Var[μ] = β / (ν(α - 1))
    
    Reference: Amini et al., "Deep Evidential Regression", NeurIPS 2020.
    """
    
    def __init__(self, input_dim: int, n_outputs: int):
        super().__init__()
        self.n_outputs = n_outputs
        # Output: 4 parameters per target (γ, ν, α, β)
        self.fc = nn.Linear(input_dim, n_outputs * 4)
        
        # Initialize carefully for stable NIG parameters
        nn.init.xavier_normal_(self.fc.weight, gain=0.1)
        nn.init.zeros_(self.fc.bias)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Returns
        -------
        gamma : Tensor (B, n_outputs) - predictive mean
        nu : Tensor (B, n_outputs) - evidence for mean (>0)
        alpha : Tensor (B, n_outputs) - evidence for variance (>1)
        beta : Tensor (B, n_outputs) - scale for variance (>0)
        """
        out = self.fc(x).reshape(-1, self.n_outputs, 4)
        
        gamma = out[..., 0]
        nu = F.softplus(out[..., 1]) + 1e-6       # ν > 0
        alpha = F.softplus(out[..., 2]) + 1.0 + 1e-6  # α > 1
        beta = F.softplus(out[..., 3]) + 1e-6      # β > 0
        
        return gamma, nu, alpha, beta
    
    @staticmethod
    def aleatoric_uncertainty(alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """E[σ²] = β / (α - 1)"""
        return beta / (alpha - 1.0).clamp(min=1e-6)
    
    @staticmethod
    def epistemic_uncertainty(nu: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Var[μ] = β / (ν(α - 1))"""
        return beta / (nu * (alpha - 1.0).clamp(min=1e-6))
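

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): the uncertainty
# decomposition above evaluated on plain floats. It shows that the epistemic
# term is the aleatoric term divided by ν, so more evidence for the mean
# shrinks only the epistemic part. The helper name is hypothetical.
# -----------------------------------------------------------------------------
def _sketch_nig_uncertainty(nu, alpha, beta):
    """Return (aleatoric, epistemic) for scalar NIG parameters."""
    if nu <= 0 or alpha <= 1 or beta <= 0:
        raise ValueError("requires nu > 0, alpha > 1, beta > 0")
    aleatoric = beta / (alpha - 1.0)         # E[σ²]
    epistemic = aleatoric / nu               # Var[μ] = β / (ν(α - 1))
    return aleatoric, epistemic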


# =============================================================================
# 4. DEPTH-CONDITIONED GATING NETWORK (Learned Dynamic Fusion)
# =============================================================================

class DepthConditionedGatingNetwork(nn.Module):
    """
    Input-conditioned gating network that dynamically determines how to
    fuse XGBoost and Neural Network predictions.
    
    Unlike a fixed 60/40 weighting, this network learns WHEN each expert
    is more reliable — conditioned on both input features and focusing depth.
    
    Key insight from DELE (arxiv:2302.00932): the gating network benefits
    from seeing the same features as the experts, plus the experts' own
    predictions as additional input.
    
    Architecture:
        [input_features ⊕ depth_physics ⊕ expert_predictions] → MLP → softmax(2)
    """
    
    def __init__(
        self,
        input_dim: int,
        depth_dim: int,
        n_expert_outputs: int,
        n_experts: int = 2,
        hidden_dim: int = 64,
    ):
        super().__init__()
        total_input = input_dim + depth_dim + n_expert_outputs * n_experts
        
        self.gate = nn.Sequential(
            nn.Linear(total_input, hidden_dim),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_experts),
        )
        
        # Temperature parameter (learnable) for softmax sharpness
        self.temperature = nn.Parameter(torch.ones(1))
    
    def forward(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        expert_preds: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        features : Tensor (B, input_dim)
        depth_physics : Tensor (B, depth_dim)
        expert_preds : list of Tensor (B, n_outputs) per expert
        
        Returns
        -------
        weights : Tensor (B, n_experts) - softmax weights summing to 1
        """
        gate_input = torch.cat(
            [features, depth_physics] + expert_preds, dim=-1
        )
        logits = self.gate(gate_input) / self.temperature.clamp(min=0.1)
        return F.softmax(logits, dim=-1)
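

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): the effect of the
# temperature on the gate, written with the standard library only. Logits are
# divided by a temperature clamped from below (0.1 here, mirroring the clamp in
# the forward pass above) before the softmax. The helper name is hypothetical.
# -----------------------------------------------------------------------------
import math


def _sketch_tempered_softmax(logits, temperature=1.0, min_temperature=0.1):
    """Softmax over logits / max(temperature, min_temperature)."""
    t = max(temperature, min_temperature)
    scaled = [x / t for x in logits]
    shift = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# A lower temperature sharpens the weights toward the expert with the larger
# logit; temperatures below the clamp behave like the clamp value itself.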


# =============================================================================
# 5. COMPLETE DCDE MODEL
# =============================================================================

class DCDE(nn.Module):
    """
    Depth-Conditioned Dynamic Ensemble (DCDE)
    
    A hybrid architecture for predicting femtosecond laser internal etching
    geometry in hydrogels. Combines:
    
    1. XGBoost branch: Pre-trained gradient-boosted trees capturing 
       complex tabular feature interactions (frozen during DCDE training)
    
    2. FiLM-NN branch: Depth-conditioned neural network where focusing 
       depth modulates intermediate representations via FiLM layers
    
    3. Dynamic gating: Input-conditioned fusion network that learns 
       optimal weighting between branches depending on input regime
    
    4. Evidential head: NIG distribution output providing calibrated 
       aleatoric + epistemic uncertainty
    
    5. Physics-informed loss: Soft monotonicity constraints and energy 
       conservation regularization
    
    Training protocol (3-phase, following DELE):
        Phase 1: Train XGBoost independently on tabular features
        Phase 2: Train FiLM-NN with evidential head (XGBoost frozen)
        Phase 3: Train gating network jointly (optionally fine-tune FiLM-NN)
    """
    
    def __init__(
        self,
        input_dim: int,
        depth_physics_dim: int = 5,
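        hidden_dims: Optional[List[int]] = None,
        n_outputs: int = 5,
        n_experts: int = 2,
        gating_hidden: int = 64,
    ):
        super().__init__()
        # Default kept out of the signature to avoid a shared mutable default
        if hidden_dims is None:
            hidden_dims = [128, 96, 64]
        self.input_dim = input_dim
        self.n_outputs = n_outputs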
        
        # FiLM-conditioned NN branch
        self.film_nn = FiLMConditionedMLP(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            output_dim=hidden_dims[-1],
            conditioning_dim=depth_physics_dim,
        )
        
        # XGBoost prediction embedding (projects XGB outputs to latent space)
        self.xgb_embed = nn.Sequential(
            nn.Linear(n_outputs, hidden_dims[-1]),
            nn.SiLU(),
            nn.Linear(hidden_dims[-1], hidden_dims[-1]),
        )
        
        # Gating network
        self.gating = DepthConditionedGatingNetwork(
            input_dim=input_dim,
            depth_dim=depth_physics_dim,
            n_expert_outputs=n_outputs,
            n_experts=n_experts,
            hidden_dim=gating_hidden,
        )
        
        # Evidential head (NIG parameters)
        self.evidential_head = EvidentialHead(hidden_dims[-1], n_outputs)
        
        # Direct output head for XGBoost branch (for gating comparison)
        self.xgb_output = nn.Linear(hidden_dims[-1], n_outputs)
    
    def forward(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        xgb_predictions: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        features : Tensor (B, input_dim) - all input features
        depth_physics : Tensor (B, depth_physics_dim) - computed physics features
        xgb_predictions : Tensor (B, n_outputs) - pre-computed XGBoost predictions
        
        Returns
        -------
        dict with keys:
            'gamma' : predictive mean (B, n_outputs)
            'nu', 'alpha', 'beta' : NIG parameters
            'aleatoric_unc' : aleatoric uncertainty
            'epistemic_unc' : epistemic uncertainty
            'gate_weights' : expert weights (B, n_experts), ordered [XGBoost, NN]
            'nn_pred' : NN-branch mean prediction before fusion (B, n_outputs)
            'xgb_pred' : the XGBoost predictions passed in, unchanged (B, n_outputs)
        """
        # NN branch: depth-conditioned via FiLM
        nn_latent = self.film_nn(features, depth_physics)
        
        # XGBoost branch: embed predictions into latent space
        xgb_latent = self.xgb_embed(xgb_predictions)
        
        # Compute intermediate predictions for gating input
        nn_pred_raw = self.evidential_head(nn_latent)[0]  # Just gamma
        
        # Dynamic gating: determine expert weights
        gate_weights = self.gating(
            features, depth_physics,
            [xgb_predictions, nn_pred_raw.detach()]  # Detach to avoid circular gradients
        )
        
        # Fused latent representation
        w_xgb = gate_weights[:, 0:1]  # (B, 1)
        w_nn = gate_weights[:, 1:2]   # (B, 1)
        fused_latent = w_xgb * xgb_latent + w_nn * nn_latent
        
        # Evidential output
        gamma, nu, alpha, beta = self.evidential_head(fused_latent)
        
        # Uncertainty decomposition
        aleatoric = EvidentialHead.aleatoric_uncertainty(alpha, beta)
        epistemic = EvidentialHead.epistemic_uncertainty(nu, alpha, beta)
        
        return {
            "gamma": gamma,
            "nu": nu,
            "alpha": alpha,
            "beta": beta,
            "aleatoric_unc": aleatoric,
            "epistemic_unc": epistemic,
            "gate_weights": gate_weights,
            "nn_pred": nn_pred_raw,
            "xgb_pred": xgb_predictions,
        }


# =============================================================================
# 6. LOSS FUNCTIONS (NIG + Physics-Informed)
# =============================================================================

class DCDELoss(nn.Module):
    """
    Composite loss for DCDE training:
    
    L_total = L_NIG + λ_mono·L_monotonicity + λ_energy·L_energy + λ_gate·L_gate_entropy
    
    Components:
    1. NIG Loss (evidential regression) - primary data fitting
    2. Monotonicity loss - enforces physical depth-etch relationships
    3. Energy conservation - volume scales with deposited energy
    4. Gate entropy regularization - prevents degenerate gating
    """
    
    def __init__(
        self,
        lambda_nig_reg: float = 0.01,
        lambda_mono: float = 0.05,
        lambda_energy: float = 0.02,
        lambda_gate: float = 0.01,
        depth_feature_idx: int = -1,
        power_feature_idx: int = 0,
    ):
        super().__init__()
        self.lambda_nig_reg = lambda_nig_reg
        self.lambda_mono = lambda_mono
        self.lambda_energy = lambda_energy
        self.lambda_gate = lambda_gate
        self.depth_idx = depth_feature_idx
        self.power_idx = power_feature_idx
    
    def nig_loss(
        self,
        y: torch.Tensor,
        gamma: torch.Tensor,
        nu: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
    ) -> torch.Tensor:
        """
        Normal-Inverse-Gamma negative log-likelihood with evidence regularization.
        
        L = L_NLL + λ·L_evidence_regularization
        
        The regularization penalizes high evidence (ν, α) when the prediction
        is wrong, encouraging the model to be uncertain when inaccurate.
        """
        # NLL term
        omega = 2 * beta * (1 + nu)
        nll = (
            0.5 * torch.log(torch.pi / nu.clamp(min=1e-6))
            - alpha * torch.log(omega.clamp(min=1e-10))
            + (alpha + 0.5) * torch.log(
                ((y - gamma) ** 2 * nu + omega).clamp(min=1e-10)
            )
            + torch.lgamma(alpha) - torch.lgamma(alpha + 0.5)
        )
        
        # Evidence regularization (penalize evidence when wrong)
        error = torch.abs(y - gamma)
        evidence = 2 * nu + alpha
        reg = error * evidence
        
        return (nll + self.lambda_nig_reg * reg).mean()
    
    def monotonicity_loss(
        self,
        features: torch.Tensor,
        gamma: torch.Tensor,
        model: nn.Module,
        depth_physics: torch.Tensor,
        xgb_pred: torch.Tensor,
    ) -> torch.Tensor:
        """
        Soft monotonicity constraint: at fixed depth, increasing laser power
        should not decrease the predicted etch size.

        Targets checked:
        - target 0 (etch_depth) should not shrink when power rises
        - target 1 (etch_width) should not shrink when power rises

        Implemented as a one-sided finite-difference penalty: the power
        feature is nudged upward by 5% of its magnitude and any decrease in
        the prediction is penalized. Only power is perturbed here; other
        parameters (e.g. number of passes) are not checked.
        """
        # Perturb power upward by a small amount. Using the magnitude keeps the
        # step positive even if the feature was standardized to negative values.
        power = features[:, self.power_idx]
        features_perturbed = features.clone()
        features_perturbed[:, self.power_idx] = power + 0.05 * power.abs()

        # Predictions for the perturbed input. Gradients are kept so the
        # penalty can raise the perturbed prediction as well as lower the
        # unperturbed one.
        output_perturbed = model(features_perturbed, depth_physics, xgb_pred)

        # Depth and width should not decrease with power (soft constraint).
        # Only violations are penalized (relu of the negative finite difference).
        violation_depth = F.relu(gamma[:, 0] - output_perturbed["gamma"][:, 0])
        violation_width = F.relu(gamma[:, 1] - output_perturbed["gamma"][:, 1])

        return (violation_depth.mean() + violation_width.mean()) / 2
    
    def energy_conservation_loss(
        self,
        features: torch.Tensor,
        gamma: torch.Tensor,
    ) -> torch.Tensor:
        """
        Soft energy constraint: predicted ablated volume should correlate
        positively with deposited energy.
        
        Volume proxy ∝ depth × width²
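        Energy proxy ∝ power (scan speed and number of passes are not yet included)

        We penalize anti-correlation, measured as the Pearson correlation
        across the batch (cosine similarity of the mean-centered proxies).
        """
        # Volume proxy from predictions
        depth_pred = gamma[:, 0].clamp(min=0)
        width_pred = gamma[:, 1].clamp(min=0)
        volume_proxy = depth_pred * width_pred ** 2

        # Energy proxy from inputs
        power = features[:, self.power_idx].clamp(min=1e-6)
        energy_proxy = power  # Simplified; could include scan speed, passes

        # Center both proxies over the batch. Without centering, the cosine
        # similarity of two non-negative vectors is never negative, so the
        # penalty below would always be zero.
        volume_centered = volume_proxy - volume_proxy.mean()
        energy_centered = energy_proxy - energy_proxy.mean()

        # Penalize negative correlation only
        corr = F.cosine_similarity(
            volume_centered.unsqueeze(-1),
            energy_centered.unsqueeze(-1),
            dim=0,
        )
        return F.relu(-corr).mean()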
    
    def gate_entropy_loss(self, gate_weights: torch.Tensor) -> torch.Tensor:
        """
        Encourage non-degenerate gating (not always choosing one expert).

        Returns the gap between the maximum possible entropy, log(n_experts),
        and the batch-mean entropy of the gate weights. The gap is zero for
        uniform weights and grows as the gate collapses onto a single expert.
        """
        # Per-sample entropy
        entropy = -(gate_weights * torch.log(gate_weights + 1e-8)).sum(dim=-1)
        # Maximizing entropy is the same as minimizing the gap to its maximum
        max_entropy = math.log(gate_weights.shape[-1])
        return max_entropy - entropy.mean()
    
    def forward(
        self,
        y: torch.Tensor,
        model_output: Dict[str, torch.Tensor],
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        model: Optional[nn.Module] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute total loss with all components.
        
        Returns dict with individual loss components for logging.
        """
        gamma = model_output["gamma"]
        nu = model_output["nu"]
        alpha = model_output["alpha"]
        beta = model_output["beta"]
        gate_weights = model_output["gate_weights"]
        xgb_pred = model_output["xgb_pred"]
        
        # Primary loss: NIG
        l_nig = self.nig_loss(y, gamma, nu, alpha, beta)
        
        # Physics losses
        l_mono = torch.tensor(0.0, device=y.device)
        if model is not None and self.lambda_mono > 0:
            l_mono = self.monotonicity_loss(features, gamma, model, depth_physics, xgb_pred)
        
        l_energy = torch.tensor(0.0, device=y.device)
        if self.lambda_energy > 0:
            l_energy = self.energy_conservation_loss(features, gamma)
        
        # Gating regularization
        l_gate = self.gate_entropy_loss(gate_weights)
        
        # Total
        total = (
            l_nig
            + self.lambda_mono * l_mono
            + self.lambda_energy * l_energy
            + self.lambda_gate * l_gate
        )
        
        return {
            "total": total,
            "nig": l_nig,
            "monotonicity": l_mono,
            "energy": l_energy,
            "gate_entropy": l_gate,
        }
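

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): scalar versions of
# two loss terms above, written with the standard library only, so the
# formulas can be checked by hand. Both helper names are hypothetical, and each
# function handles a single sample rather than a batch.
# -----------------------------------------------------------------------------
import math


def _sketch_nig_nll(y, gamma, nu, alpha, beta):
    """NIG negative log-likelihood for one target, with Ω = 2β(1 + ν)."""
    omega = 2.0 * beta * (1.0 + nu)
    return (
        0.5 * math.log(math.pi / nu)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log((y - gamma) ** 2 * nu + omega)
        + math.lgamma(alpha) - math.lgamma(alpha + 0.5)
    )


def _sketch_gate_entropy_penalty(weights):
    """Gap between log(n_experts) and the entropy of one weight vector."""
    entropy = -sum(w * math.log(w + 1e-8) for w in weights)
    return math.log(len(weights)) - entropy


# The NLL grows with the squared error at fixed evidence, and the entropy
# penalty is near zero for uniform gate weights and larger for peaked ones.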


# =============================================================================
# 7. TRAINING UTILITIES
# =============================================================================

class DCDETrainer:
    """
    Three-phase training protocol for DCDE.
    
    Phase 1: Train XGBoost on tabular features (external, uses sklearn/xgboost)
    Phase 2: Train FiLM-NN with evidential head (XGBoost predictions as input)
    Phase 3: Train gating network + fine-tune FiLM-NN end-to-end
    """
    
    def __init__(
        self,
        model: DCDE,
        loss_fn: DCDELoss,
        lr_phase2: float = 1e-3,
        lr_phase3: float = 3e-4,
        weight_decay: float = 1e-4,
        device: str = "cpu",
    ):
        self.model = model.to(device)
        self.loss_fn = loss_fn
        self.lr_phase2 = lr_phase2
        self.lr_phase3 = lr_phase3
        self.weight_decay = weight_decay
        self.device = device
    
    def phase2_train_step(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        xgb_predictions: torch.Tensor,
        targets: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """Single training step for Phase 2 (FiLM-NN + evidential head)."""
        self.model.train()
        optimizer.zero_grad()
        
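        # Move the batch to the model's device so CPU tensors work with a GPU model
        features = features.to(self.device)
        depth_physics = depth_physics.to(self.device)
        xgb_predictions = xgb_predictions.to(self.device)
        targets = targets.to(self.device)

        output = self.model(features, depth_physics, xgb_predictions)
        losses = self.loss_fn(targets, output, features, depth_physics, self.model)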
        
        losses["total"].backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        optimizer.step()
        
        return {k: v.item() for k, v in losses.items()}
    
    def phase3_train_step(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        xgb_predictions: torch.Tensor,
        targets: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """Single training step for Phase 3 (end-to-end with gating)."""
        # Same update as Phase 2. The caller decides what is trained by passing
        # an optimizer built over the unfrozen parameters (e.g. with lr_phase3).
        return self.phase2_train_step(features, depth_physics, xgb_predictions, targets, optimizer)
    
    @torch.no_grad()
    def predict(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        xgb_predictions: torch.Tensor,
    ) -> Dict[str, np.ndarray]:
        """
        Inference with uncertainty quantification.
        
        Returns
        -------
        dict with:
            'mean': predicted values (B, n_outputs)
            'aleatoric_unc': aleatoric uncertainty per target
            'epistemic_unc': epistemic uncertainty per target
            'total_unc': total predictive uncertainty
            'gate_weights': expert weights showing XGB vs NN dominance
        """
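        self.model.eval()
        # Move inputs to the model's device before the forward pass
        features = features.to(self.device)
        depth_physics = depth_physics.to(self.device)
        xgb_predictions = xgb_predictions.to(self.device)
        output = self.model(features, depth_physics, xgb_predictions)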
        
        return {
            "mean": output["gamma"].cpu().numpy(),
            "aleatoric_unc": output["aleatoric_unc"].cpu().numpy(),
            "epistemic_unc": output["epistemic_unc"].cpu().numpy(),
            "total_unc": (output["aleatoric_unc"] + output["epistemic_unc"]).cpu().numpy(),
            "gate_weights": output["gate_weights"].cpu().numpy(),
        }