/**
 * UGTC: Uncertainty-Gated Temporal Credit — Java Reference Implementation
 * =========================================================================
 *
 * A pure Java reference implementation of the UGTC module.
 * No external dependencies. Uses simple float[][] arrays for matrix ops.
 *
 * This is a reference implementation for readability and portability.
 * For production use, consider using a Java deep learning framework
 * (DL4J, PyTorch Java API) with proper GPU support.
 *
 * Paper: https://doi.org/10.5281/zenodo.19715116
 */
package ai.ethosoft.ugtc;

import java.util.Random;

/**
 * Blends a single "fast" critic with an ensemble of "slow" critics through an
 * uncertainty gate derived from the ensemble's disagreement.
 *
 * <p>Instances are not thread-safe: {@link #computeGate(float[], boolean)} with
 * {@code train == true} updates the running sigma average without synchronization.
 *
 * <p>The {@link Config} passed to the constructor is stored by reference. Its
 * {@code hiddenDim} and {@code M} fields are only read while the networks are built;
 * the remaining fields are read on every call.
 */
public class UGTCModule {

    /** Seed used by the constructors that do not take an explicit seed. */
    private static final long DEFAULT_SEED = 42L;

    /** Discount factor used by the four-argument {@code computeAdvantages} overload. */
    private static final float DEFAULT_GAMMA = 0.99f;

    // ──────────────────────────────────────────────────────────────────────────
    // Configuration
    // ──────────────────────────────────────────────────────────────────────────

    /** Hyper-parameters of the module. */
    public static class Config {
        /** Width of the two hidden layers of every value network. */
        public int hiddenDim = 64;
        /** Number of members in the slow ensemble. */
        public int M = 3;
        /** GAE lambda applied to the fast critic's advantages. */
        public float lambdaFast = 0.80f;
        /** GAE lambda applied to the slow ensemble's advantages. */
        public float lambdaSlow = 0.99f;
        /** Multiplier inside the gate's sigmoid. */
        public float beta = 5.0f;
        /** Momentum of the running average of the ensemble standard deviation. */
        public float emaMomentum = 0.99f;
        /** Added to the running average before dividing by it. */
        public float eps = 1e-8f;

        /** Creates a configuration with the default values. */
        public Config() {}

        /**
         * Creates a configuration with explicit values; {@code eps} keeps its default.
         *
         * @param hiddenDim   width of the hidden layers
         * @param M           number of ensemble members
         * @param lambdaFast  GAE lambda for the fast critic
         * @param lambdaSlow  GAE lambda for the slow ensemble
         * @param beta        multiplier inside the gate's sigmoid
         * @param emaMomentum momentum of the sigma running average
         */
        public Config(int hiddenDim, int M, float lambdaFast, float lambdaSlow,
                      float beta, float emaMomentum) {
            this.hiddenDim = hiddenDim;
            this.M = M;
            this.lambdaFast = lambdaFast;
            this.lambdaSlow = lambdaSlow;
            this.beta = beta;
            this.emaMomentum = emaMomentum;
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Linear layer
    // ──────────────────────────────────────────────────────────────────────────

    /** Fully connected layer computing {@code W·x + b}. */
    private static class Linear {
        final float[][] W; // [outDim][inDim]
        final float[] b;   // [outDim]

        /**
         * Initializes every weight to {@code +scale} or {@code -scale}, where
         * {@code scale = sqrt(2 / inDim)} and the sign follows one Gaussian draw;
         * biases start at zero.
         */
        Linear(int inDim, int outDim, Random rng) {
            W = new float[outDim][inDim];
            b = new float[outDim];
            float scale = (float) Math.sqrt(2.0 / inDim);
            for (int row = 0; row < outDim; row++) {
                for (int col = 0; col < inDim; col++) {
                    // NOTE(review): only the sign of the Gaussian sample is used. If a
                    // He-normal init was intended this would be nextGaussian() * scale;
                    // left unchanged so existing numeric results are preserved.
                    W[row][col] = rng.nextGaussian() > 0 ? scale : -scale;
                }
            }
        }

        /**
         * Applies the layer to {@code x}.
         *
         * <p>The caller must pass a vector whose length equals this layer's input width;
         * {@link UGTCModule#computeGate(float[], boolean)} enforces that for observations.
         */
        float[] forward(float[] x) {
            int outDim = W.length;
            float[] out = new float[outDim];
            for (int row = 0; row < outDim; row++) {
                float[] weights = W[row];
                float acc = b[row];
                for (int col = 0; col < x.length; col++) {
                    acc += weights[col] * x[col];
                }
                out[row] = acc;
            }
            return out;
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Value network: obs → h → h → scalar
    // ──────────────────────────────────────────────────────────────────────────

    /** Two tanh hidden layers followed by a linear scalar head. */
    private static class ValueNetwork {
        final Linear fc1, fc2, fc3;

        ValueNetwork(int obsDim, int hiddenDim, Random rng) {
            fc1 = new Linear(obsDim, hiddenDim, rng);
            fc2 = new Linear(hiddenDim, hiddenDim, rng);
            fc3 = new Linear(hiddenDim, 1, rng);
        }

        /** @return the scalar value predicted for {@code obs} */
        float forward(float[] obs) {
            float[] hidden = applyTanh(fc1.forward(obs));
            hidden = applyTanh(fc2.forward(hidden));
            return fc3.forward(hidden)[0];
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Ensemble value network
    // ──────────────────────────────────────────────────────────────────────────

    /** A fixed set of independently initialized {@link ValueNetwork}s. */
    private static class EnsembleValueNetwork {
        final ValueNetwork[] members;

        EnsembleValueNetwork(int obsDim, int hiddenDim, int M, Random rng) {
            members = new ValueNetwork[M];
            for (int i = 0; i < M; i++) {
                members[i] = new ValueNetwork(obsDim, hiddenDim, rng);
            }
        }

        /**
         * Evaluates every member on {@code obs}.
         *
         * @return float[] {mean, std} of ensemble predictions; the variance is divided
         *         by {@code M - 1}, or by 1 when there is a single member
         */
        float[] forward(float[] obs) {
            int count = members.length;
            float[] predictions = new float[count];
            float sum = 0.0f;
            for (int i = 0; i < count; i++) {
                predictions[i] = members[i].forward(obs);
                sum += predictions[i];
            }
            float mean = sum / count;

            float variance = 0.0f;
            for (float prediction : predictions) {
                float deviation = prediction - mean;
                variance += deviation * deviation;
            }
            variance /= (count > 1 ? count - 1 : 1);

            return new float[]{ mean, (float) Math.sqrt(variance) };
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Gate result
    // ──────────────────────────────────────────────────────────────────────────

    /** Immutable result of {@link #computeGate(float[], boolean)}. */
    public static class GateResult {
        /** Gate value u(s). */
        public final float gate;
        /** Fast critic's value estimate. */
        public final float vFast;
        /** Mean of the slow ensemble's value estimates. */
        public final float vSlow;
        /** Standard deviation of the slow ensemble's value estimates. */
        public final float sigma;

        GateResult(float gate, float vFast, float vSlow, float sigma) {
            this.gate = gate;
            this.vFast = vFast;
            this.vSlow = vSlow;
            this.sigma = sigma;
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Module fields
    // ──────────────────────────────────────────────────────────────────────────

    private final Config config;
    private final int obsDim;
    private final ValueNetwork fastCritic;
    private final EnsembleValueNetwork slowEnsemble;
    /** Running average of the ensemble standard deviation; starts at 1. */
    private float sigmaEMA;

    // ──────────────────────────────────────────────────────────────────────────
    // Constructor
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Creates a module with the default {@link Config} and the default seed.
     *
     * @param obsDim length of every observation vector; must be positive
     * @throws IllegalArgumentException if {@code obsDim} is not positive
     */
    public UGTCModule(int obsDim) {
        this(obsDim, new Config());
    }

    /**
     * Creates a module with the given configuration and the default seed.
     *
     * @param obsDim length of every observation vector; must be positive
     * @param config hyper-parameters; stored by reference
     * @throws NullPointerException     if {@code config} is null
     * @throws IllegalArgumentException if {@code obsDim}, {@code config.hiddenDim} or
     *                                  {@code config.M} is not positive
     */
    public UGTCModule(int obsDim, Config config) {
        this(obsDim, config, DEFAULT_SEED);
    }

    /**
     * Creates a module whose network weights are derived from {@code seed}.
     *
     * @param obsDim length of every observation vector; must be positive
     * @param config hyper-parameters; stored by reference
     * @param seed   seed of the random generator used to initialize all networks
     * @throws NullPointerException     if {@code config} is null
     * @throws IllegalArgumentException if {@code obsDim}, {@code config.hiddenDim} or
     *                                  {@code config.M} is not positive
     */
    public UGTCModule(int obsDim, Config config, long seed) {
        if (config == null) {
            throw new NullPointerException("config must not be null");
        }
        requirePositive("obsDim", obsDim);
        requirePositive("config.hiddenDim", config.hiddenDim);
        // An empty ensemble would produce a 0/0 mean, i.e. NaN for every value.
        requirePositive("config.M", config.M);

        this.config = config;
        this.obsDim = obsDim;
        this.sigmaEMA = 1.0f;

        // The fast critic is built first, then the ensemble, from one shared generator.
        Random rng = new Random(seed);
        this.fastCritic = new ValueNetwork(obsDim, config.hiddenDim, rng);
        this.slowEnsemble = new EnsembleValueNetwork(obsDim, config.hiddenDim, config.M, rng);
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Gate computation
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Compute the uncertainty gate u(s) for a single observation.
     *
     * u(s) = sigmoid(-β · (σ̂(s) - 1))
     * where σ̂(s) = σ(s) / σ_EMA
     *
     * <p>When {@code train} is true the running average σ_EMA is updated before the gate
     * is computed. A NaN or infinite σ(s) is never folded into the average, so a single
     * bad observation cannot corrupt later gates.
     *
     * @param obs   Observation vector; its length must equal the module's {@code obsDim}
     * @param train Whether to update EMA statistics
     * @return GateResult containing gate, v_fast, v_slow, sigma
     * @throws NullPointerException     if {@code obs} is null
     * @throws IllegalArgumentException if {@code obs} has the wrong length
     */
    public GateResult computeGate(float[] obs, boolean train) {
        requireObservation(obs, "obs");

        float vFast = fastCritic.forward(obs);
        float[] ensembleStats = slowEnsemble.forward(obs);
        float vSlow = ensembleStats[0];
        float sigma = ensembleStats[1];

        if (train && Float.isFinite(sigma)) {
            sigmaEMA = config.emaMomentum * sigmaEMA
                    + (1.0f - config.emaMomentum) * sigma;
        }

        float normalizedSigma = sigma / (sigmaEMA + config.eps);
        float gate = sigmoid(-config.beta * (normalizedSigma - 1.0f));

        return new GateResult(gate, vFast, vSlow, sigma);
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Value estimation
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Blended value estimate: V^UGTC(s) = u(s)·V̄_slow(s) + (1-u(s))·V_fast(s)
     *
     * @param obs   Observation vector; its length must equal the module's {@code obsDim}
     * @param train Whether to update EMA
     * @return Scalar blended value
     * @throws NullPointerException     if {@code obs} is null
     * @throws IllegalArgumentException if {@code obs} has the wrong length
     */
    public float getValueUGTC(float[] obs, boolean train) {
        GateResult result = computeGate(obs, train);
        return result.gate * result.vSlow + (1.0f - result.gate) * result.vFast;
    }

    // ──────────────────────────────────────────────────────────────────────────
    // GAE computation
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Standard Generalized Advantage Estimation.
     *
     * δₜ = rₜ + γ·V(sₜ₊₁)·(1-dₜ) - V(sₜ)
     * Aₜ = δₜ + γλ·(1-dₜ)·Aₜ₊₁
     *
     * <p>The trajectory length is {@code rewards.length}. The other arrays must hold at
     * least that many entries; any extra trailing entries are ignored.
     *
     * @param rewards  Reward sequence
     * @param values   Current-state values
     * @param nextVals Next-state values
     * @param dones    Episode termination flags (1.0 = done)
     * @param gamma    Discount factor
     * @param lam      GAE lambda
     * @return Array of advantage estimates, one per reward
     * @throws IllegalArgumentException if {@code values}, {@code nextVals} or
     *                                  {@code dones} is shorter than {@code rewards}
     */
    public static float[] computeGAE(
        float[] rewards,
        float[] values,
        float[] nextVals,
        float[] dones,
        float gamma,
        float lam
    ) {
        int steps = rewards.length;
        requireMinLength("values", values.length, steps);
        requireMinLength("nextVals", nextVals.length, steps);
        requireMinLength("dones", dones.length, steps);

        float[] advantages = new float[steps];
        float runningAdvantage = 0.0f;

        // Walk backwards so each step can reuse the advantage of its successor.
        for (int t = steps - 1; t >= 0; t--) {
            float notDone = 1.0f - dones[t];
            float delta = rewards[t] + gamma * nextVals[t] * notDone - values[t];
            runningAdvantage = delta + gamma * lam * notDone * runningAdvantage;
            advantages[t] = runningAdvantage;
        }

        return advantages;
    }

    // ──────────────────────────────────────────────────────────────────────────
    // UGTC advantage computation
    // ──────────────────────────────────────────────────────────────────────────

    /**
     * Compute UGTC blended advantages for a trajectory.
     *
     * A^UGTC_t = u(sₜ)·A^slow_t + (1-u(sₜ))·A^fast_t
     *
     * <p>All inputs are validated before any observation is evaluated, so a malformed
     * trajectory never leaves the running sigma average partially updated. Next-state
     * observations never update the average, regardless of {@code train}.
     *
     * @param obsSeq     Sequence of observations (T × obsDim)
     * @param nextObsSeq Sequence of next observations (T × obsDim)
     * @param rewards    Reward sequence (T,)
     * @param dones      Done flags (T,)
     * @param gamma      Discount factor
     * @param train      Whether to update EMA
     * @return UGTC blended advantages (T,)
     * @throws NullPointerException     if any observation row is null
     * @throws IllegalArgumentException if a sequence is shorter than {@code rewards} or
     *                                  an observation row has the wrong length
     */
    public float[] computeAdvantages(
        float[][] obsSeq,
        float[][] nextObsSeq,
        float[] rewards,
        float[] dones,
        float gamma,
        boolean train
    ) {
        int steps = rewards.length;
        requireMinLength("obsSeq", obsSeq.length, steps);
        requireMinLength("nextObsSeq", nextObsSeq.length, steps);
        requireMinLength("dones", dones.length, steps);
        for (int t = 0; t < steps; t++) {
            requireObservation(obsSeq[t], "obsSeq[" + t + "]");
            requireObservation(nextObsSeq[t], "nextObsSeq[" + t + "]");
        }

        float[] gates     = new float[steps];
        float[] vFastArr  = new float[steps];
        float[] vSlowArr  = new float[steps];
        float[] vFastNext = new float[steps];
        float[] vSlowNext = new float[steps];

        for (int t = 0; t < steps; t++) {
            GateResult current = computeGate(obsSeq[t], train);
            GateResult next = computeGate(nextObsSeq[t], false);
            gates[t]     = current.gate;
            vFastArr[t]  = current.vFast;
            vSlowArr[t]  = current.vSlow;
            vFastNext[t] = next.vFast;
            vSlowNext[t] = next.vSlow;
        }

        float[] advFast = computeGAE(rewards, vFastArr, vFastNext, dones, gamma, config.lambdaFast);
        float[] advSlow = computeGAE(rewards, vSlowArr, vSlowNext, dones, gamma, config.lambdaSlow);

        float[] advantages = new float[steps];
        for (int t = 0; t < steps; t++) {
            advantages[t] = gates[t] * advSlow[t] + (1.0f - gates[t]) * advFast[t];
        }
        return advantages;
    }

    // Convenience overload with training=false and gamma=0.99
    public float[] computeAdvantages(float[][] obsSeq, float[][] nextObsSeq,
                                     float[] rewards, float[] dones) {
        return computeAdvantages(obsSeq, nextObsSeq, rewards, dones, DEFAULT_GAMMA, false);
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Accessors
    // ──────────────────────────────────────────────────────────────────────────

    /** @return the current running average of the ensemble standard deviation */
    public float getSigmaEMA() { return sigmaEMA; }

    /** @return the configuration object this module was constructed with (not a copy) */
    public Config getConfig() { return config; }

    // ──────────────────────────────────────────────────────────────────────────
    // Utility functions
    // ──────────────────────────────────────────────────────────────────────────

    private static float sigmoid(float x) {
        return 1.0f / (1.0f + (float) Math.exp(-x));
    }

    private static float[] applyTanh(float[] x) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (float) Math.tanh(x[i]);
        }
        return out;
    }

    /** Rejects non-positive structural sizes with a descriptive message. */
    private static void requirePositive(String name, int value) {
        if (value <= 0) {
            throw new IllegalArgumentException(name + " must be positive but was " + value);
        }
    }

    /** Rejects a sequence that holds fewer entries than the trajectory needs. */
    private static void requireMinLength(String name, int actual, int required) {
        if (actual < required) {
            throw new IllegalArgumentException(
                name + " must have at least " + required + " entries but has " + actual);
        }
    }

    /** Rejects a null observation or one whose length differs from {@code obsDim}. */
    private void requireObservation(float[] obs, String name) {
        if (obs == null) {
            throw new NullPointerException(name + " must not be null");
        }
        if (obs.length != obsDim) {
            throw new IllegalArgumentException(
                name + " must have length " + obsDim + " but has length " + obs.length);
        }
    }

    // ──────────────────────────────────────────────────────────────────────────
    // Main — minimal smoke test
    // ──────────────────────────────────────────────────────────────────────────

    public static void main(String[] args) {
        System.out.println("UGTC Java Reference Implementation");
        System.out.println("Paper: https://doi.org/10.5281/zenodo.19715116");
        System.out.println();

        int obsDim = 17;
        int T = 32;
        UGTCModule ugtc = new UGTCModule(obsDim);

        // Random trajectory
        Random rng = new Random(0);
        float[][] obs = new float[T][obsDim];
        float[][] nextObs = new float[T][obsDim];
        float[] rewards = new float[T];
        float[] dones = new float[T];

        for (int t = 0; t < T; t++) {
            for (int d = 0; d < obsDim; d++) {
                obs[t][d] = (float) rng.nextGaussian();
                nextObs[t][d] = (float) rng.nextGaussian();
            }
            rewards[t] = (float) rng.nextGaussian();
            dones[t] = (t == T - 1) ? 1.0f : 0.0f;
        }

        float[] advantages = ugtc.computeAdvantages(obs, nextObs, rewards, dones, 0.99f, true);

        System.out.printf("obs_dim: %d T: %d%n", obsDim, T);
        System.out.printf("Advantages: [%.4f, %.4f, %.4f, ...]%n",
            advantages[0], advantages[1], advantages[2]);

        // Gate check
        GateResult gate = ugtc.computeGate(obs[0], false);
        System.out.printf("Gate u(s₀): %.4f σ_EMA: %.4f%n", gate.gate, ugtc.getSigmaEMA());

        System.out.println("\nSmoke test passed.");
    }
}