File size: 17,675 Bytes
d92d8cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
/**
 * UGTC: Uncertainty-Gated Temporal Credit — Java Reference Implementation
 * =========================================================================
 *
 * A pure Java reference implementation of the UGTC module.
 * No external dependencies. Uses simple float[][] arrays for matrix ops.
 *
 * This is a reference implementation for readability and portability.
 * For production use, consider using a Java deep learning framework
 * (DL4J, PyTorch Java API) with proper GPU support.
 *
 * Paper: https://doi.org/10.5281/zenodo.19715116
 */

package ai.ethosoft.ugtc;

import java.util.Objects;
import java.util.Random;

/**
 * Uncertainty-Gated Temporal Credit (UGTC) module.
 *
 * <p>Holds one "fast" value critic and an ensemble of "slow" value critics. The sample standard
 * deviation of the ensemble's predictions, normalized by an exponential moving average (EMA) of
 * past standard deviations, drives a sigmoid gate. The gate blends the two critics' value
 * estimates and their GAE advantages (computed with {@code lambdaFast} and {@code lambdaSlow}).
 *
 * <p>This reference implementation performs forward passes only: network weights are randomly
 * initialized in the constructor and are never updated by this class.
 *
 * <p>Not thread-safe: calls with {@code train = true} mutate the EMA statistic without
 * synchronization.
 */
public class UGTCModule {

    // ──────────────────────────────────────────────────────────────────────────
    // Configuration
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Hyper-parameters for {@link UGTCModule}.
     *
     * <p>Fields are public and mutable. {@code hiddenDim} and {@code M} are read only when a
     * module is constructed. The remaining fields are read on every gate/advantage call, so
     * changing them on a config that a live module holds affects subsequent results.
     */
    public static class Config {
        /** Width of both hidden layers of every critic network (must be at least 1). */
        public int   hiddenDim    = 64;
        /** Number of critics in the slow ensemble (must be at least 1). */
        public int   M            = 3;
        /** GAE lambda applied to the fast critic's advantages. */
        public float lambdaFast   = 0.80f;
        /** GAE lambda applied to the slow ensemble's advantages. */
        public float lambdaSlow   = 0.99f;
        /** Gate sharpness; larger values make the gate switch more abruptly around σ̂ = 1. */
        public float beta         = 5.0f;
        /** Momentum of the EMA of the ensemble standard deviation. */
        public float emaMomentum  = 0.99f;
        /** Added to the EMA denominator to avoid division by zero. */
        public float eps          = 1e-8f;

        /** Creates a config with the default values above. */
        public Config() {}

        /** Creates a config with explicit values; {@code eps} keeps its default. */
        public Config(int hiddenDim, int M, float lambdaFast, float lambdaSlow,
                      float beta, float emaMomentum) {
            this.hiddenDim   = hiddenDim;
            this.M           = M;
            this.lambdaFast  = lambdaFast;
            this.lambdaSlow  = lambdaSlow;
            this.beta        = beta;
            this.emaMomentum = emaMomentum;
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Linear layer
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Dense affine layer {@code y = W·x + b}.
     *
     * <p>Each weight is set to ±sqrt(2 / inDim), with the sign taken from a Gaussian draw of
     * {@code rng`. This gives the same variance as He initialization. Biases start at zero.
     */
    private static final class Linear {
        final float[][] W;      // [outDim][inDim]
        final float[]   b;      // [outDim]
        final int       inDim;  // expected input length

        Linear(int inDim, int outDim, Random rng) {
            this.inDim = inDim;
            W = new float[outDim][inDim];
            b = new float[outDim];
            float scale = (float) Math.sqrt(2.0 / inDim);
            for (int i = 0; i < outDim; i++) {
                for (int j = 0; j < inDim; j++) {
                    // One Gaussian draw per weight; only its sign is used.
                    W[i][j] = rng.nextGaussian() > 0 ? scale : -scale;
                }
            }
        }

        /**
         * Applies the layer to {@code x}.
         *
         * @throws IllegalArgumentException if {@code x.length} differs from the layer's input width
         */
        float[] forward(float[] x) {
            if (x.length != inDim) {
                throw new IllegalArgumentException(
                    "expected input of length " + inDim + " but got " + x.length);
            }
            int outDim = W.length;
            float[] out = new float[outDim];
            for (int i = 0; i < outDim; i++) {
                out[i] = b[i];
                for (int j = 0; j < inDim; j++) {
                    out[i] += W[i][j] * x[j];
                }
            }
            return out;
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Value network: obs → h → h → scalar
    // ──────────────────────────────────────────────────────────────────────────

    /** Three-layer MLP critic: {@code fc3(tanh(fc2(tanh(fc1(obs)))))}, producing a scalar. */
    private static final class ValueNetwork {
        final Linear fc1, fc2, fc3;

        ValueNetwork(int obsDim, int hiddenDim, Random rng) {
            fc1 = new Linear(obsDim,    hiddenDim, rng);
            fc2 = new Linear(hiddenDim, hiddenDim, rng);
            fc3 = new Linear(hiddenDim, 1,         rng);
        }

        float forward(float[] obs) {
            float[] h1 = applyTanh(fc1.forward(obs));
            float[] h2 = applyTanh(fc2.forward(h1));
            return fc3.forward(h2)[0];
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Ensemble value network
    // ──────────────────────────────────────────────────────────────────────────

    /** Ensemble of independently initialized {@link ValueNetwork}s. */
    private static final class EnsembleValueNetwork {
        final ValueNetwork[] members;

        EnsembleValueNetwork(int obsDim, int hiddenDim, int M, Random rng) {
            members = new ValueNetwork[M];
            for (int i = 0; i < M; i++) {
                members[i] = new ValueNetwork(obsDim, hiddenDim, rng);
            }
        }

        /**
         * Evaluates every member on {@code obs}.
         *
         * @return {@code {mean, std}} of the member predictions. {@code std} is the sample
         *         standard deviation (divides by M - 1) and is 0 when there is a single member.
         */
        float[] forward(float[] obs) {
            int M = members.length;
            float[] vals = new float[M];
            float mean   = 0.0f;

            for (int i = 0; i < M; i++) {
                vals[i] = members[i].forward(obs);
                mean    += vals[i];
            }
            mean /= M;

            float var = 0.0f;
            for (float v : vals) {
                var += (v - mean) * (v - mean);
            }
            var /= (M > 1 ? M - 1 : 1);

            return new float[]{ mean, (float) Math.sqrt(var) };
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Gate result
    // ──────────────────────────────────────────────────────────────────────────

    /** Output of {@link #computeGate(float[], boolean)}. */
    public static class GateResult {
        /** Gate value u(s) in [0, 1]; the weight given to the slow ensemble. */
        public final float gate;
        /** Fast critic value V_fast(s). */
        public final float vFast;
        /** Mean slow-ensemble value V̄_slow(s). */
        public final float vSlow;
        /** Raw (unnormalized) standard deviation of the slow-ensemble predictions. */
        public final float sigma;

        GateResult(float gate, float vFast, float vSlow, float sigma) {
            this.gate  = gate;
            this.vFast = vFast;
            this.vSlow = vSlow;
            this.sigma = sigma;
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Module fields
    // ──────────────────────────────────────────────────────────────────────────

    /** RNG seed used by the constructors that do not take an explicit seed. */
    public static final long DEFAULT_SEED = 42L;

    private final Config config;
    private final ValueNetwork fastCritic;
    private final EnsembleValueNetwork slowEnsemble;
    /** EMA of the ensemble std; starts at 1 and is updated only when {@code train} is true. */
    private float sigmaEMA;

    // ──────────────────────────────────────────────────────────────────────────
    // Constructor
    // ──────────────────────────────────────────────────────────────────────────

    /** Creates a module with the default {@link Config} and {@link #DEFAULT_SEED}. */
    public UGTCModule(int obsDim) {
        this(obsDim, new Config());
    }

    /** Creates a module with the given config and {@link #DEFAULT_SEED}. */
    public UGTCModule(int obsDim, Config config) {
        this(obsDim, config, DEFAULT_SEED);
    }

    /**
     * Creates a module whose networks are initialized from {@code new Random(seed)}.
     *
     * @param obsDim length of every observation vector (must be at least 1)
     * @param config hyper-parameters; the module keeps a reference, not a copy
     * @param seed   seed for weight initialization
     * @throws NullPointerException     if {@code config} is null
     * @throws IllegalArgumentException if {@code obsDim}, {@code config.hiddenDim} or
     *                                  {@code config.M} is less than 1
     */
    public UGTCModule(int obsDim, Config config, long seed) {
        this.config = Objects.requireNonNull(config, "config");
        if (obsDim < 1) {
            throw new IllegalArgumentException("obsDim must be >= 1, got " + obsDim);
        }
        if (config.hiddenDim < 1) {
            throw new IllegalArgumentException("hiddenDim must be >= 1, got " + config.hiddenDim);
        }
        if (config.M < 1) {
            throw new IllegalArgumentException("ensemble size M must be >= 1, got " + config.M);
        }
        this.sigmaEMA = 1.0f;
        Random rng = new Random(seed);
        // Construction order fixes which RNG draws each network gets: fast critic first.
        this.fastCritic   = new ValueNetwork(obsDim, config.hiddenDim, rng);
        this.slowEnsemble = new EnsembleValueNetwork(obsDim, config.hiddenDim, config.M, rng);
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Gate computation
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Computes the uncertainty gate u(s) for a single observation.
     *
     * <pre>
     *   u(s)  = sigmoid(-β · (σ̂(s) - 1))
     *   σ̂(s) = σ(s) / (σ_EMA + eps)
     * </pre>
     *
     * <p>σ(s) is the standard deviation of the slow ensemble's predictions. Low relative
     * uncertainty (σ̂ &lt; 1) pushes the gate above 0.5, favoring the slow ensemble. When
     * {@code train} is true, σ_EMA is updated with σ(s) <em>before</em> the normalization.
     *
     * @param obs   observation vector of length {@code obsDim}
     * @param train whether to update the EMA statistic
     * @return gate, fast value, mean slow value and raw ensemble std
     * @throws IllegalArgumentException if {@code obs} has the wrong length
     */
    public GateResult computeGate(float[] obs, boolean train) {
        float vFast     = fastCritic.forward(obs);
        float[] ensOut  = slowEnsemble.forward(obs);
        float vSlow     = ensOut[0];
        float sigma     = ensOut[1];

        if (train) {
            sigmaEMA = config.emaMomentum * sigmaEMA
                     + (1.0f - config.emaMomentum) * sigma;
        }

        float normalizedSigma = sigma / (sigmaEMA + config.eps);
        float gate = sigmoid(-config.beta * (normalizedSigma - 1.0f));

        return new GateResult(gate, vFast, vSlow, sigma);
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Value estimation
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Blended value estimate: {@code V^UGTC(s) = u(s)·V̄_slow(s) + (1 - u(s))·V_fast(s)}.
     *
     * @param obs   observation vector of length {@code obsDim}
     * @param train whether to update the EMA statistic
     * @return scalar blended value
     */
    public float getValueUGTC(float[] obs, boolean train) {
        GateResult r = computeGate(obs, train);
        return r.gate * r.vSlow + (1.0f - r.gate) * r.vFast;
    }

    // ──────────────────────────────────────────────────────────────────────────
    // GAE computation
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Standard Generalized Advantage Estimation, computed backwards over the trajectory.
     *
     * <pre>
     *   δₜ = rₜ + γ·V(sₜ₊₁)·(1 - dₜ) - V(sₜ)
     *   Aₜ = δₜ + γλ·(1 - dₜ)·Aₜ₊₁
     * </pre>
     *
     * @param rewards  reward sequence of length T
     * @param values   current-state values (length T)
     * @param nextVals next-state values (length T)
     * @param dones    episode termination flags (length T, 1.0 = done)
     * @param gamma    discount factor
     * @param lam      GAE lambda
     * @return advantage estimates (length T)
     * @throws IllegalArgumentException if the array lengths differ
     */
    public static float[] computeGAE(
        float[] rewards, float[] values, float[] nextVals, float[] dones,
        float gamma, float lam
    ) {
        int T = rewards.length;
        requireLength("values",   values.length,   T);
        requireLength("nextVals", nextVals.length, T);
        requireLength("dones",    dones.length,    T);

        float[] advantages = new float[T];
        float gae = 0.0f;

        for (int t = T - 1; t >= 0; t--) {
            float delta = rewards[t] + gamma * nextVals[t] * (1.0f - dones[t]) - values[t];
            gae = delta + gamma * lam * (1.0f - dones[t]) * gae;
            advantages[t] = gae;
        }
        return advantages;
    }

    // ──────────────────────────────────────────────────────────────────────────
    // UGTC advantage computation
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Computes UGTC blended advantages for a trajectory.
     *
     * <p>{@code A^UGTC_t = u(sₜ)·A^slow_t + (1 - u(sₜ))·A^fast_t}, where A^fast uses
     * {@code lambdaFast} and A^slow uses {@code lambdaSlow}. The EMA is updated only from the
     * current observations (never from the next observations), and only when {@code train}
     * is true. All lengths are validated before any state is touched.
     *
     * @param obsSeq     observations (T × obsDim)
     * @param nextObsSeq next observations (T × obsDim)
     * @param rewards    reward sequence (T)
     * @param dones      done flags (T)
     * @param gamma      discount factor
     * @param train      whether to update the EMA statistic
     * @return UGTC blended advantages (T)
     * @throws IllegalArgumentException if sequence lengths differ or an observation has the
     *                                  wrong length
     */
    public float[] computeAdvantages(
        float[][] obsSeq, float[][] nextObsSeq,
        float[] rewards, float[] dones,
        float gamma, boolean train
    ) {
        int T = rewards.length;
        // Validate up front so a bad call cannot leave sigmaEMA partially updated.
        requireLength("obsSeq",     obsSeq.length,     T);
        requireLength("nextObsSeq", nextObsSeq.length, T);
        requireLength("dones",      dones.length,      T);

        float[] gates      = new float[T];
        float[] vFastArr   = new float[T];
        float[] vSlowArr   = new float[T];
        float[] vFastNext  = new float[T];
        float[] vSlowNext  = new float[T];

        for (int t = 0; t < T; t++) {
            GateResult r     = computeGate(obsSeq[t], train);
            GateResult rNext = computeGate(nextObsSeq[t], false);
            gates[t]      = r.gate;
            vFastArr[t]   = r.vFast;
            vSlowArr[t]   = r.vSlow;
            vFastNext[t]  = rNext.vFast;
            vSlowNext[t]  = rNext.vSlow;
        }

        float[] advFast = computeGAE(rewards, vFastArr, vFastNext, dones, gamma, config.lambdaFast);
        float[] advSlow = computeGAE(rewards, vSlowArr, vSlowNext, dones, gamma, config.lambdaSlow);

        float[] advantages = new float[T];
        for (int t = 0; t < T; t++) {
            advantages[t] = gates[t] * advSlow[t] + (1.0f - gates[t]) * advFast[t];
        }
        return advantages;
    }

    /** Convenience overload with {@code gamma = 0.99} and {@code train = false}. */
    public float[] computeAdvantages(float[][] obsSeq, float[][] nextObsSeq,
                                     float[] rewards, float[] dones) {
        return computeAdvantages(obsSeq, nextObsSeq, rewards, dones, 0.99f, false);
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Accessors
    // ──────────────────────────────────────────────────────────────────────────

    /** Current EMA of the ensemble standard deviation. */
    public float getSigmaEMA() { return sigmaEMA; }

    /** The live config held by this module (not a copy). */
    public Config getConfig()  { return config; }

    // ──────────────────────────────────────────────────────────────────────────
    // Utility functions
    // ──────────────────────────────────────────────────────────────────────────

    /** Throws if {@code actual != expected}, naming the offending argument. */
    private static void requireLength(String name, int actual, int expected) {
        if (actual != expected) {
            throw new IllegalArgumentException(
                name + " has length " + actual + " but expected " + expected);
        }
    }

    private static float sigmoid(float x) {
        return 1.0f / (1.0f + (float) Math.exp(-x));
    }

    private static float[] applyTanh(float[] x) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (float) Math.tanh(x[i]);
        }
        return out;
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Main — minimal smoke test
    // ──────────────────────────────────────────────────────────────────────────

    public static void main(String[] args) {
        System.out.println("UGTC Java Reference Implementation");
        System.out.println("Paper: https://doi.org/10.5281/zenodo.19715116");
        System.out.println();

        int obsDim = 17;
        int T      = 32;
        UGTCModule ugtc = new UGTCModule(obsDim);

        // Random trajectory
        Random rng = new Random(0);
        float[][] obs      = new float[T][obsDim];
        float[][] nextObs  = new float[T][obsDim];
        float[]   rewards  = new float[T];
        float[]   dones    = new float[T];

        for (int t = 0; t < T; t++) {
            for (int d = 0; d < obsDim; d++) {
                obs[t][d]     = (float) rng.nextGaussian();
                nextObs[t][d] = (float) rng.nextGaussian();
            }
            rewards[t] = (float) rng.nextGaussian();
            dones[t]   = (t == T - 1) ? 1.0f : 0.0f;
        }

        float[] advantages = ugtc.computeAdvantages(obs, nextObs, rewards, dones, 0.99f, true);

        System.out.printf("obs_dim: %d  T: %d%n", obsDim, T);
        System.out.printf("Advantages: [%.4f, %.4f, %.4f, ...]%n",
            advantages[0], advantages[1], advantages[2]);

        // Gate check
        GateResult gate = ugtc.computeGate(obs[0], false);
        System.out.printf("Gate u(s₀): %.4f  σ_EMA: %.4f%n",
            gate.gate, ugtc.getSigmaEMA());

        System.out.println("\nSmoke test passed.");
    }
}