/**
 * UGTC: Uncertainty-Gated Temporal Credit - Java Reference Implementation
 * =======================================================================
 *
 * A pure Java reference implementation of the UGTC module.
 * No external dependencies; uses plain float[][] arrays for matrix operations.
 *
 * This implementation favors readability and portability over speed.
 * For production use, consider a Java deep learning framework
 * (e.g. DL4J or the PyTorch Java API) with proper GPU support.
 *
 * Paper: https://doi.org/10.5281/zenodo.19715116
 */
package ai.ethosoft.ugtc;

import java.util.Objects;
import java.util.Random;
/**
 * Uncertainty-Gated Temporal Credit (UGTC) module.
 *
 * <p>Holds one fast critic and an ensemble of {@code M} slow critics. For each observation the
 * ensemble's disagreement (sample standard deviation of its predictions) is normalized by an
 * exponential moving average and mapped through a sigmoid to a gate {@code u(s)} in (0, 1).
 * The gate blends the slow and fast critics' values and GAE advantages.
 *
 * <p>Not thread-safe: {@link #computeGate(float[], boolean)} with {@code train == true}
 * mutates the internal sigma EMA.
 */
public class UGTCModule {

    /** Seed used by the constructors that do not take an explicit seed. */
    private static final long DEFAULT_SEED = 42L;

    // ------------------------------------------------------------------------
    // Configuration
    // ------------------------------------------------------------------------

    /**
     * Hyperparameters for {@link UGTCModule}.
     *
     * <p>Fields are public and mutable. The module keeps a reference to the instance it was
     * built with. Later changes to {@code lambdaFast}, {@code lambdaSlow}, {@code beta},
     * {@code emaMomentum} and {@code eps} therefore affect subsequent calls. Changes to
     * {@code hiddenDim} and {@code M} do not affect networks that were already built.
     */
    public static class Config {
        /** Width of both hidden layers of every critic network. */
        public int hiddenDim = 64;
        /** Number of members in the slow-critic ensemble. */
        public int M = 3;
        /** GAE lambda used for the fast critic's advantages. */
        public float lambdaFast = 0.80f;
        /** GAE lambda used for the slow ensemble's advantages. */
        public float lambdaSlow = 0.99f;
        /** Sharpness of the sigmoid gate. */
        public float beta = 5.0f;
        /** EMA momentum: weight kept on the previous sigma EMA value at each update. */
        public float emaMomentum = 0.99f;
        /** Added to the sigma EMA in the gate's denominator to avoid division by zero. */
        public float eps = 1e-8f;

        /** Creates a config with the default hyperparameters. */
        public Config() {}

        /** Creates a config with explicit hyperparameters; {@code eps} keeps its default. */
        public Config(int hiddenDim, int M, float lambdaFast, float lambdaSlow,
                      float beta, float emaMomentum) {
            this.hiddenDim = hiddenDim;
            this.M = M;
            this.lambdaFast = lambdaFast;
            this.lambdaSlow = lambdaSlow;
            this.beta = beta;
            this.emaMomentum = emaMomentum;
        }
    }

    // ------------------------------------------------------------------------
    // Linear layer
    // ------------------------------------------------------------------------

    /** Dense affine layer computing {@code W x + b}. */
    private static class Linear {
        final float[][] W; // [outDim][inDim]
        final float[] b;   // [outDim], left at zero
        final int inDim;

        Linear(int inDim, int outDim, Random rng) {
            this.inDim = inDim;
            W = new float[outDim][inDim];
            b = new float[outDim];
            // Each weight is +scale or -scale, with the sign taken from a Gaussian draw.
            // This gives per-weight variance 2/inDim (He-style), but the weights are NOT
            // Gaussian. NOTE(review): if He-normal init was intended, use
            // (float) (rng.nextGaussian() * scale) -- that would change every module output.
            float scale = (float) Math.sqrt(2.0 / inDim);
            for (int i = 0; i < outDim; i++) {
                for (int j = 0; j < inDim; j++) {
                    W[i][j] = rng.nextGaussian() > 0 ? scale : -scale;
                }
            }
        }

        /**
         * Applies the layer to one input vector.
         *
         * @throws IllegalArgumentException if {@code x.length} differs from the layer's
         *     input width (previously a short input silently produced a truncated result)
         */
        float[] forward(float[] x) {
            if (x.length != inDim) {
                throw new IllegalArgumentException(
                        "expected input of length " + inDim + ", got " + x.length);
            }
            int outDim = W.length;
            float[] out = new float[outDim];
            for (int i = 0; i < outDim; i++) {
                out[i] = b[i];
                for (int j = 0; j < inDim; j++) {
                    out[i] += W[i][j] * x[j];
                }
            }
            return out;
        }
    }

    // ------------------------------------------------------------------------
    // Value network: obs -> tanh(fc1) -> tanh(fc2) -> fc3 -> scalar
    // ------------------------------------------------------------------------

    /** Two-hidden-layer tanh MLP producing a scalar value estimate. */
    private static class ValueNetwork {
        final Linear fc1, fc2, fc3;

        ValueNetwork(int obsDim, int hiddenDim, Random rng) {
            fc1 = new Linear(obsDim, hiddenDim, rng);
            fc2 = new Linear(hiddenDim, hiddenDim, rng);
            fc3 = new Linear(hiddenDim, 1, rng);
        }

        float forward(float[] obs) {
            float[] h1 = applyTanh(fc1.forward(obs));
            float[] h2 = applyTanh(fc2.forward(h1));
            return fc3.forward(h2)[0];
        }
    }

    // ------------------------------------------------------------------------
    // Ensemble value network
    // ------------------------------------------------------------------------

    /** Independently initialized value networks whose predictions are aggregated. */
    private static class EnsembleValueNetwork {
        final ValueNetwork[] members;

        EnsembleValueNetwork(int obsDim, int hiddenDim, int M, Random rng) {
            members = new ValueNetwork[M];
            for (int i = 0; i < M; i++) {
                members[i] = new ValueNetwork(obsDim, hiddenDim, rng);
            }
        }

        /**
         * Evaluates every member on {@code obs}.
         *
         * @return {@code {mean, std}} of the member predictions; std uses the M-1 (Bessel)
         *     denominator and is 0 when there is a single member
         */
        float[] forward(float[] obs) {
            int M = members.length;
            float[] vals = new float[M];
            float mean = 0.0f;
            for (int i = 0; i < M; i++) {
                vals[i] = members[i].forward(obs);
                mean += vals[i];
            }
            mean /= M;
            float variance = 0.0f;
            for (float v : vals) {
                variance += (v - mean) * (v - mean);
            }
            variance /= (M > 1 ? M - 1 : 1);
            return new float[]{ mean, (float) Math.sqrt(variance) };
        }
    }

    // ------------------------------------------------------------------------
    // Gate result
    // ------------------------------------------------------------------------

    /** Output of {@link #computeGate(float[], boolean)} for one observation. */
    public static class GateResult {
        /** Gate value u(s) in (0, 1); weight given to the slow ensemble. */
        public final float gate;
        /** Fast critic's value estimate. */
        public final float vFast;
        /** Mean of the slow ensemble's value estimates. */
        public final float vSlow;
        /** Sample standard deviation of the slow ensemble's estimates (not normalized). */
        public final float sigma;

        GateResult(float gate, float vFast, float vSlow, float sigma) {
            this.gate = gate;
            this.vFast = vFast;
            this.vSlow = vSlow;
            this.sigma = sigma;
        }
    }

    // ------------------------------------------------------------------------
    // Module fields
    // ------------------------------------------------------------------------

    private final Config config;
    private final ValueNetwork fastCritic;
    private final EnsembleValueNetwork slowEnsemble;
    /** Running EMA of ensemble sigma; starts at 1.0 and is updated only when train is true. */
    private float sigmaEMA;

    // ------------------------------------------------------------------------
    // Constructors
    // ------------------------------------------------------------------------

    /** Creates a module with default hyperparameters and the default seed (42). */
    public UGTCModule(int obsDim) {
        this(obsDim, new Config());
    }

    /** Creates a module with the given hyperparameters and the default seed (42). */
    public UGTCModule(int obsDim, Config config) {
        this(obsDim, config, DEFAULT_SEED);
    }

    /**
     * Creates a module whose network weights are drawn from {@code new Random(seed)}.
     *
     * @param obsDim length of every observation vector passed to this module
     * @param config hyperparameters; the reference is retained, not copied
     * @param seed   seed for weight initialization
     * @throws NullPointerException     if {@code config} is null
     * @throws IllegalArgumentException if {@code obsDim}, {@code config.hiddenDim} or
     *     {@code config.M} is not positive
     */
    public UGTCModule(int obsDim, Config config, long seed) {
        Objects.requireNonNull(config, "config");
        requirePositive("obsDim", obsDim);
        requirePositive("config.hiddenDim", config.hiddenDim);
        // M == 0 would make the ensemble mean 0/0 = NaN.
        requirePositive("config.M", config.M);
        this.config = config;
        this.sigmaEMA = 1.0f;
        Random rng = new Random(seed);
        // Construction order fixes the RNG stream: fast critic first, then ensemble members.
        this.fastCritic = new ValueNetwork(obsDim, config.hiddenDim, rng);
        this.slowEnsemble = new EnsembleValueNetwork(obsDim, config.hiddenDim, config.M, rng);
    }

    // ------------------------------------------------------------------------
    // Gate computation
    // ------------------------------------------------------------------------

    /**
     * Computes the uncertainty gate u(s) for a single observation.
     *
     * <pre>
     *   u(s)        = sigmoid(-beta * (sigmaHat(s) - 1))
     *   sigmaHat(s) = sigma(s) / (sigmaEMA + eps)
     * </pre>
     *
     * When {@code train} is true, the sigma EMA is updated with this observation's sigma
     * BEFORE the gate is computed.
     *
     * @param obs   observation vector of length {@code obsDim}
     * @param train whether to update the sigma EMA
     * @return gate, fast value, slow (ensemble-mean) value and raw ensemble sigma
     * @throws IllegalArgumentException if {@code obs} has the wrong length
     */
    public GateResult computeGate(float[] obs, boolean train) {
        float vFast = fastCritic.forward(obs);
        float[] ensOut = slowEnsemble.forward(obs);
        float vSlow = ensOut[0];
        float sigma = ensOut[1];
        if (train) {
            sigmaEMA = config.emaMomentum * sigmaEMA
                    + (1.0f - config.emaMomentum) * sigma;
        }
        float normalizedSigma = sigma / (sigmaEMA + config.eps);
        float gate = sigmoid(-config.beta * (normalizedSigma - 1.0f));
        return new GateResult(gate, vFast, vSlow, sigma);
    }

    // ------------------------------------------------------------------------
    // Value estimation
    // ------------------------------------------------------------------------

    /**
     * Blended value estimate: {@code V(s) = u(s) * V_slow(s) + (1 - u(s)) * V_fast(s)}.
     *
     * @param obs   observation vector of length {@code obsDim}
     * @param train whether to update the sigma EMA
     * @return scalar blended value
     */
    public float getValueUGTC(float[] obs, boolean train) {
        GateResult r = computeGate(obs, train);
        return r.gate * r.vSlow + (1.0f - r.gate) * r.vFast;
    }

    // ------------------------------------------------------------------------
    // GAE computation
    // ------------------------------------------------------------------------

    /**
     * Standard Generalized Advantage Estimation, computed backwards in time.
     *
     * <pre>
     *   delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
     *   A_t     = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}
     * </pre>
     *
     * @param rewards  reward sequence, length T
     * @param values   current-state values, length T
     * @param nextVals next-state values, length T
     * @param dones    episode termination flags (1.0 = done), length T
     * @param gamma    discount factor
     * @param lam      GAE lambda
     * @return advantage estimates, length T
     * @throws IllegalArgumentException if the array lengths differ
     */
    public static float[] computeGAE(
            float[] rewards, float[] values, float[] nextVals, float[] dones,
            float gamma, float lam
    ) {
        int T = rewards.length;
        requireLength("values", values.length, T);
        requireLength("nextVals", nextVals.length, T);
        requireLength("dones", dones.length, T);
        float[] advantages = new float[T];
        float gae = 0.0f;
        for (int t = T - 1; t >= 0; t--) {
            float delta = rewards[t] + gamma * nextVals[t] * (1.0f - dones[t]) - values[t];
            gae = delta + gamma * lam * (1.0f - dones[t]) * gae;
            advantages[t] = gae;
        }
        return advantages;
    }

    // ------------------------------------------------------------------------
    // UGTC advantage computation
    // ------------------------------------------------------------------------

    /**
     * Computes UGTC blended advantages for a trajectory:
     * {@code A_t = u(s_t) * A_slow_t + (1 - u(s_t)) * A_fast_t}.
     *
     * Fast advantages use {@code config.lambdaFast} and slow advantages use
     * {@code config.lambdaSlow}. When {@code train} is true, the sigma EMA is updated once
     * per timestep from {@code obsSeq[t]}, in order; next observations never update it.
     *
     * @param obsSeq     observations, T x obsDim
     * @param nextObsSeq next observations, T x obsDim
     * @param rewards    rewards, length T
     * @param dones      done flags, length T
     * @param gamma      discount factor
     * @param train      whether to update the sigma EMA
     * @return blended advantages, length T
     * @throws IllegalArgumentException if sequence lengths differ or an observation has the
     *     wrong length
     */
    public float[] computeAdvantages(
            float[][] obsSeq, float[][] nextObsSeq,
            float[] rewards, float[] dones,
            float gamma, boolean train
    ) {
        int T = rewards.length;
        requireLength("obsSeq", obsSeq.length, T);
        requireLength("nextObsSeq", nextObsSeq.length, T);
        requireLength("dones", dones.length, T);
        float[] gates = new float[T];
        float[] vFastArr = new float[T];
        float[] vSlowArr = new float[T];
        float[] vFastNext = new float[T];
        float[] vSlowNext = new float[T];
        for (int t = 0; t < T; t++) {
            GateResult r = computeGate(obsSeq[t], train);
            GateResult rNext = computeGate(nextObsSeq[t], false);
            gates[t] = r.gate;
            vFastArr[t] = r.vFast;
            vSlowArr[t] = r.vSlow;
            vFastNext[t] = rNext.vFast;
            vSlowNext[t] = rNext.vSlow;
        }
        float[] advFast = computeGAE(rewards, vFastArr, vFastNext, dones, gamma, config.lambdaFast);
        float[] advSlow = computeGAE(rewards, vSlowArr, vSlowNext, dones, gamma, config.lambdaSlow);
        float[] advantages = new float[T];
        for (int t = 0; t < T; t++) {
            advantages[t] = gates[t] * advSlow[t] + (1.0f - gates[t]) * advFast[t];
        }
        return advantages;
    }

    /** Convenience overload with {@code gamma = 0.99} and {@code train = false}. */
    public float[] computeAdvantages(float[][] obsSeq, float[][] nextObsSeq,
                                     float[] rewards, float[] dones) {
        return computeAdvantages(obsSeq, nextObsSeq, rewards, dones, 0.99f, false);
    }

    // ------------------------------------------------------------------------
    // Accessors
    // ------------------------------------------------------------------------

    /** Returns the current sigma EMA (1.0 until the first training-mode gate call). */
    public float getSigmaEMA() { return sigmaEMA; }

    /** Returns the live (not copied) config this module was built with. */
    public Config getConfig() { return config; }

    // ------------------------------------------------------------------------
    // Utility functions
    // ------------------------------------------------------------------------

    private static float sigmoid(float x) {
        return 1.0f / (1.0f + (float) Math.exp(-x));
    }

    private static float[] applyTanh(float[] x) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (float) Math.tanh(x[i]);
        }
        return out;
    }

    /** Throws IllegalArgumentException unless {@code value > 0}. */
    private static void requirePositive(String name, int value) {
        if (value <= 0) {
            throw new IllegalArgumentException(name + " must be positive, got " + value);
        }
    }

    /** Throws IllegalArgumentException unless {@code actual == expected}. */
    private static void requireLength(String name, int actual, int expected) {
        if (actual != expected) {
            throw new IllegalArgumentException(
                    name + " has length " + actual + ", expected " + expected);
        }
    }

    // ------------------------------------------------------------------------
    // Main - minimal smoke test
    // ------------------------------------------------------------------------

    public static void main(String[] args) {
        System.out.println("UGTC Java Reference Implementation");
        System.out.println("Paper: https://doi.org/10.5281/zenodo.19715116");
        System.out.println();
        int obsDim = 17;
        int T = 32;
        UGTCModule ugtc = new UGTCModule(obsDim);
        // Random trajectory; the episode terminates on the last step.
        Random rng = new Random(0);
        float[][] obs = new float[T][obsDim];
        float[][] nextObs = new float[T][obsDim];
        float[] rewards = new float[T];
        float[] dones = new float[T];
        for (int t = 0; t < T; t++) {
            for (int d = 0; d < obsDim; d++) {
                obs[t][d] = (float) rng.nextGaussian();
                nextObs[t][d] = (float) rng.nextGaussian();
            }
            rewards[t] = (float) rng.nextGaussian();
            dones[t] = (t == T - 1) ? 1.0f : 0.0f;
        }
        float[] advantages = ugtc.computeAdvantages(obs, nextObs, rewards, dones, 0.99f, true);
        System.out.printf("obs_dim: %d T: %d%n", obsDim, T);
        System.out.printf("Advantages: [%.4f, %.4f, %.4f, ...]%n",
                advantages[0], advantages[1], advantages[2]);
        // Gate check (does not update the EMA)
        GateResult gate = ugtc.computeGate(obs[0], false);
        System.out.printf("Gate u(sβ): %.4f Ο_EMA: %.4f%n",
                gate.gate, ugtc.getSigmaEMA());
        System.out.println("\nSmoke test passed.");
    }
}