// ugtc / implementations / java / UGTCModule.java
// Ekrem-the-second's picture
// Initial release: UGTC - Uncertainty-Gated Temporal Credit
// d92d8cf verified
// Raw
// History Blame Contribute Delete
// 17.7 kB
/**
* UGTC: Uncertainty-Gated Temporal Credit — Java Reference Implementation
* =========================================================================
*
* A pure Java reference implementation of the UGTC module.
* No external dependencies. Uses simple float[][] arrays for matrix ops.
*
* This is a reference implementation for readability and portability.
* For production use, consider using a Java deep learning framework
* (DL4J, PyTorch Java API) with proper GPU support.
*
* Paper: https://doi.org/10.5281/zenodo.19715116
*/
package ai.ethosoft.ugtc;
import java.util.Random;
/**
 * Uncertainty-Gated Temporal Credit (UGTC) module.
 *
 * <p>Holds one "fast" value network and an ensemble of "slow" value networks.
 * The sample standard deviation of the ensemble predictions, normalised by a
 * running exponential moving average (EMA), is passed through a sigmoid to
 * form a gate u(s). The gate blends the slow and fast value estimates and the
 * slow and fast GAE advantages.
 *
 * <p>All networks are initialised from a {@link Random} seeded with 42, so two
 * modules built with the same arguments produce identical outputs. The module
 * only provides forward passes; no training of the network weights is
 * implemented here. The only mutable state is the sigma EMA, which is updated
 * without synchronisation.
 */
public class UGTCModule {

    // ----------------------------------------------------------------------
    // Configuration
    // ----------------------------------------------------------------------

    /** Hyper-parameters of the module. Fields are public and mutable. */
    public static class Config {
        /** Width of both hidden layers of every value network. */
        public int hiddenDim = 64;
        /** Number of members in the slow ensemble (must be at least 1). */
        public int M = 3;
        /** GAE lambda used together with the fast critic's values. */
        public float lambdaFast = 0.80f;
        /** GAE lambda used together with the slow ensemble's mean values. */
        public float lambdaSlow = 0.99f;
        /** Steepness of the sigmoid gate. */
        public float beta = 5.0f;
        /** Momentum of the running average of the ensemble std. */
        public float emaMomentum = 0.99f;
        /** Added to the EMA before dividing, to avoid division by zero. */
        public float eps = 1e-8f;

        /** Creates a configuration holding the default values. */
        public Config() {}

        /**
         * Creates a configuration with explicit values; {@code eps} keeps its
         * default.
         */
        public Config(int hiddenDim, int M, float lambdaFast, float lambdaSlow,
                      float beta, float emaMomentum) {
            this.hiddenDim = hiddenDim;
            this.M = M;
            this.lambdaFast = lambdaFast;
            this.lambdaSlow = lambdaSlow;
            this.beta = beta;
            this.emaMomentum = emaMomentum;
        }
    }

    // ----------------------------------------------------------------------
    // Linear layer
    // ----------------------------------------------------------------------

    /** Dense layer computing {@code W * x + b}. */
    private static class Linear {
        final float[][] W; // [outDim][inDim]
        final float[] b;   // [outDim], left at zero
        private final int inDim;

        /**
         * Every weight is set to +scale or -scale, with
         * scale = sqrt(2 / inDim); the sign follows the sign of one Gaussian
         * draw per weight. Biases stay zero.
         *
         * NOTE(review): only the sign of the Gaussian sample is used, not its
         * magnitude. Kept as-is so outputs stay reproducible; confirm whether
         * a scaled Gaussian initialisation was intended.
         */
        Linear(int inDim, int outDim, Random rng) {
            this.inDim = inDim;
            W = new float[outDim][inDim];
            b = new float[outDim];
            float scale = (float) Math.sqrt(2.0 / inDim);
            for (int i = 0; i < outDim; i++) {
                for (int j = 0; j < inDim; j++) {
                    W[i][j] = rng.nextGaussian() > 0 ? scale : -scale;
                }
            }
        }

        /**
         * @param x input vector of length {@code inDim}
         * @return a new array of length {@code outDim}
         * @throws IllegalArgumentException if {@code x} has the wrong length
         */
        float[] forward(float[] x) {
            // A shorter x would otherwise be accepted silently and only a
            // prefix of each weight row would contribute.
            if (x.length != inDim) {
                throw new IllegalArgumentException(
                    "input length " + x.length + " does not match layer input size " + inDim);
            }
            int outDim = W.length;
            float[] out = new float[outDim];
            for (int i = 0; i < outDim; i++) {
                out[i] = b[i];
                for (int j = 0; j < inDim; j++) {
                    out[i] += W[i][j] * x[j];
                }
            }
            return out;
        }
    }

    // ----------------------------------------------------------------------
    // Value network: obs -> h -> h -> scalar
    // ----------------------------------------------------------------------

    /** Two tanh hidden layers followed by a linear scalar output. */
    private static class ValueNetwork {
        final Linear fc1, fc2, fc3;

        ValueNetwork(int obsDim, int hiddenDim, Random rng) {
            fc1 = new Linear(obsDim, hiddenDim, rng);
            fc2 = new Linear(hiddenDim, hiddenDim, rng);
            fc3 = new Linear(hiddenDim, 1, rng);
        }

        float forward(float[] obs) {
            float[] h1 = applyTanh(fc1.forward(obs));
            float[] h2 = applyTanh(fc2.forward(h1));
            return fc3.forward(h2)[0];
        }
    }

    // ----------------------------------------------------------------------
    // Ensemble value network
    // ----------------------------------------------------------------------

    /** Independently initialised value networks evaluated on the same input. */
    private static class EnsembleValueNetwork {
        final ValueNetwork[] members;

        EnsembleValueNetwork(int obsDim, int hiddenDim, int M, Random rng) {
            members = new ValueNetwork[M];
            for (int i = 0; i < M; i++) {
                members[i] = new ValueNetwork(obsDim, hiddenDim, rng);
            }
        }

        /**
         * @return float[] {mean, std} of ensemble predictions. The std uses
         *         the (M - 1) denominator; with a single member it is 0.
         */
        float[] forward(float[] obs) {
            int M = members.length;
            float[] vals = new float[M];
            float mean = 0.0f;
            for (int i = 0; i < M; i++) {
                vals[i] = members[i].forward(obs);
                mean += vals[i];
            }
            mean /= M;
            float var = 0.0f;
            for (float v : vals) {
                var += (v - mean) * (v - mean);
            }
            var /= (M > 1 ? M - 1 : 1);
            return new float[]{ mean, (float) Math.sqrt(var) };
        }
    }

    // ----------------------------------------------------------------------
    // Gate result
    // ----------------------------------------------------------------------

    /** Immutable result of one gate evaluation. */
    public static class GateResult {
        /** Gate value u(s) produced by the sigmoid. */
        public final float gate;
        /** Fast critic prediction. */
        public final float vFast;
        /** Mean prediction of the slow ensemble. */
        public final float vSlow;
        /** Un-normalised standard deviation of the ensemble predictions. */
        public final float sigma;

        GateResult(float gate, float vFast, float vSlow, float sigma) {
            this.gate = gate;
            this.vFast = vFast;
            this.vSlow = vSlow;
            this.sigma = sigma;
        }
    }

    // ----------------------------------------------------------------------
    // Module fields
    // ----------------------------------------------------------------------

    private final Config config;
    private final int obsDim;
    private final ValueNetwork fastCritic;
    private final EnsembleValueNetwork slowEnsemble;
    private float sigmaEMA;

    // ----------------------------------------------------------------------
    // Constructors
    // ----------------------------------------------------------------------

    /**
     * Creates a module with the default {@link Config}.
     *
     * @param obsDim length of every observation vector (must be positive)
     */
    public UGTCModule(int obsDim) {
        this(obsDim, new Config());
    }

    /**
     * Creates a module. The config object is stored by reference, not copied.
     *
     * @param obsDim length of every observation vector (must be positive)
     * @param config hyper-parameters
     * @throws NullPointerException     if {@code config} is null
     * @throws IllegalArgumentException if {@code obsDim} or
     *         {@code config.hiddenDim} is not positive, or {@code config.M}
     *         is less than 1
     */
    public UGTCModule(int obsDim, Config config) {
        if (config == null) {
            throw new NullPointerException("config must not be null");
        }
        if (obsDim <= 0) {
            throw new IllegalArgumentException("obsDim must be positive, got " + obsDim);
        }
        if (config.hiddenDim <= 0) {
            throw new IllegalArgumentException(
                "config.hiddenDim must be positive, got " + config.hiddenDim);
        }
        // An empty ensemble would make the mean 0/0 = NaN.
        if (config.M < 1) {
            throw new IllegalArgumentException("config.M must be at least 1, got " + config.M);
        }
        this.config = config;
        this.obsDim = obsDim;
        this.sigmaEMA = 1.0f;
        // Fixed seed; the fast critic is built before the ensemble, so the
        // draw order (and therefore every weight) is reproducible.
        Random rng = new Random(42);
        this.fastCritic = new ValueNetwork(obsDim, config.hiddenDim, rng);
        this.slowEnsemble = new EnsembleValueNetwork(obsDim, config.hiddenDim, config.M, rng);
    }

    // ----------------------------------------------------------------------
    // Gate computation
    // ----------------------------------------------------------------------

    /**
     * Compute the uncertainty gate u(s) for a single observation.
     *
     * <pre>
     *   u(s)         = sigmoid(-beta * (sigma_hat(s) - 1))
     *   sigma_hat(s) = sigma(s) / (sigma_EMA + eps)
     * </pre>
     *
     * When {@code train} is true the EMA is updated with this observation's
     * sigma before the normalisation is applied.
     *
     * @param obs   Observation vector of length {@code obsDim}
     * @param train Whether to update EMA statistics
     * @return GateResult containing gate, v_fast, v_slow, sigma
     * @throws NullPointerException     if {@code obs} is null
     * @throws IllegalArgumentException if {@code obs} has the wrong length
     */
    public GateResult computeGate(float[] obs, boolean train) {
        requireObservation("obs", obs);
        float vFast = fastCritic.forward(obs);
        float[] ensOut = slowEnsemble.forward(obs);
        float vSlow = ensOut[0];
        float sigma = ensOut[1];
        if (train) {
            sigmaEMA = config.emaMomentum * sigmaEMA
                     + (1.0f - config.emaMomentum) * sigma;
        }
        float normalizedSigma = sigma / (sigmaEMA + config.eps);
        float gate = sigmoid(-config.beta * (normalizedSigma - 1.0f));
        return new GateResult(gate, vFast, vSlow, sigma);
    }

    // ----------------------------------------------------------------------
    // Value estimation
    // ----------------------------------------------------------------------

    /**
     * Blended value estimate:
     * V_UGTC(s) = u(s) * V_slow_mean(s) + (1 - u(s)) * V_fast(s)
     *
     * @param obs   Observation vector of length {@code obsDim}
     * @param train Whether to update EMA
     * @return Scalar blended value
     * @throws NullPointerException     if {@code obs} is null
     * @throws IllegalArgumentException if {@code obs} has the wrong length
     */
    public float getValueUGTC(float[] obs, boolean train) {
        GateResult r = computeGate(obs, train);
        return r.gate * r.vSlow + (1.0f - r.gate) * r.vFast;
    }

    // ----------------------------------------------------------------------
    // GAE computation
    // ----------------------------------------------------------------------

    /**
     * Standard Generalized Advantage Estimation.
     *
     * <pre>
     *   delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
     *   A_t     = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}
     * </pre>
     *
     * The trajectory length T is {@code rewards.length}; the other arrays
     * must hold at least T elements (extra elements are ignored).
     *
     * @param rewards  Reward sequence
     * @param values   Current-state values
     * @param nextVals Next-state values
     * @param dones    Episode termination flags (1.0 = done)
     * @param gamma    Discount factor
     * @param lam      GAE lambda
     * @return Array of T advantage estimates
     * @throws IllegalArgumentException if an array is shorter than T
     */
    public static float[] computeGAE(
        float[] rewards, float[] values, float[] nextVals, float[] dones,
        float gamma, float lam
    ) {
        int T = rewards.length;
        requireLength("values", values.length, T);
        requireLength("nextVals", nextVals.length, T);
        requireLength("dones", dones.length, T);

        float[] advantages = new float[T];
        float gae = 0.0f;
        // Backward sweep: A_t depends on A_{t+1}.
        for (int t = T - 1; t >= 0; t--) {
            float delta = rewards[t] + gamma * nextVals[t] * (1.0f - dones[t]) - values[t];
            gae = delta + gamma * lam * (1.0f - dones[t]) * gae;
            advantages[t] = gae;
        }
        return advantages;
    }

    // ----------------------------------------------------------------------
    // UGTC advantage computation
    // ----------------------------------------------------------------------

    /**
     * Compute UGTC blended advantages for a trajectory.
     *
     * <pre>
     *   A_UGTC_t = u(s_t) * A_slow_t + (1 - u(s_t)) * A_fast_t
     * </pre>
     *
     * The trajectory length T is {@code rewards.length}. All inputs are
     * validated before any gate is evaluated, so a rejected call leaves the
     * sigma EMA untouched. Next observations never update the EMA.
     *
     * @param obsSeq     Sequence of observations (T x obsDim)
     * @param nextObsSeq Sequence of next observations (T x obsDim)
     * @param rewards    Reward sequence (T,)
     * @param dones      Done flags (T,)
     * @param gamma      Discount factor
     * @param train      Whether to update EMA
     * @return UGTC blended advantages (T,)
     * @throws NullPointerException     if one of the first T observations is null
     * @throws IllegalArgumentException if a sequence is shorter than T or an
     *         observation has the wrong length
     */
    public float[] computeAdvantages(
        float[][] obsSeq, float[][] nextObsSeq,
        float[] rewards, float[] dones,
        float gamma, boolean train
    ) {
        int T = rewards.length;
        requireLength("obsSeq", obsSeq.length, T);
        requireLength("nextObsSeq", nextObsSeq.length, T);
        requireLength("dones", dones.length, T);
        for (int t = 0; t < T; t++) {
            requireObservation("obsSeq[" + t + "]", obsSeq[t]);
            requireObservation("nextObsSeq[" + t + "]", nextObsSeq[t]);
        }

        float[] gates = new float[T];
        float[] vFastArr = new float[T];
        float[] vSlowArr = new float[T];
        float[] vFastNext = new float[T];
        float[] vSlowNext = new float[T];
        for (int t = 0; t < T; t++) {
            GateResult r = computeGate(obsSeq[t], train);
            GateResult rNext = computeGate(nextObsSeq[t], false);
            gates[t] = r.gate;
            vFastArr[t] = r.vFast;
            vSlowArr[t] = r.vSlow;
            vFastNext[t] = rNext.vFast;
            vSlowNext[t] = rNext.vSlow;
        }

        float[] advFast = computeGAE(rewards, vFastArr, vFastNext, dones, gamma, config.lambdaFast);
        float[] advSlow = computeGAE(rewards, vSlowArr, vSlowNext, dones, gamma, config.lambdaSlow);

        float[] advantages = new float[T];
        for (int t = 0; t < T; t++) {
            advantages[t] = gates[t] * advSlow[t] + (1.0f - gates[t]) * advFast[t];
        }
        return advantages;
    }

    /** Convenience overload with train=false and gamma=0.99. */
    public float[] computeAdvantages(float[][] obsSeq, float[][] nextObsSeq,
                                     float[] rewards, float[] dones) {
        return computeAdvantages(obsSeq, nextObsSeq, rewards, dones, 0.99f, false);
    }

    // ----------------------------------------------------------------------
    // Accessors
    // ----------------------------------------------------------------------

    /** @return current running average of the ensemble std (starts at 1.0) */
    public float getSigmaEMA() { return sigmaEMA; }

    /** @return the configuration object passed at construction (not a copy) */
    public Config getConfig() { return config; }

    // ----------------------------------------------------------------------
    // Utility functions
    // ----------------------------------------------------------------------

    /** Rejects null observations and observations whose length is not obsDim. */
    private void requireObservation(String name, float[] obs) {
        if (obs == null) {
            throw new NullPointerException(name + " must not be null");
        }
        if (obs.length != obsDim) {
            throw new IllegalArgumentException(
                name + " has length " + obs.length + ", expected " + obsDim);
        }
    }

    /** Rejects a sequence that holds fewer than {@code required} elements. */
    private static void requireLength(String name, int actual, int required) {
        if (actual < required) {
            throw new IllegalArgumentException(
                name + " has length " + actual + ", expected at least " + required);
        }
    }

    private static float sigmoid(float x) {
        return 1.0f / (1.0f + (float) Math.exp(-x));
    }

    /** Element-wise tanh into a new array. */
    private static float[] applyTanh(float[] x) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (float) Math.tanh(x[i]);
        }
        return out;
    }

    // ----------------------------------------------------------------------
    // Main - minimal smoke test
    // ----------------------------------------------------------------------

    /** Runs the module on a random 32-step trajectory and prints a summary. */
    public static void main(String[] args) {
        System.out.println("UGTC Java Reference Implementation");
        System.out.println("Paper: https://doi.org/10.5281/zenodo.19715116");
        System.out.println();

        int obsDim = 17;
        int T = 32;
        UGTCModule ugtc = new UGTCModule(obsDim);

        // Random trajectory
        Random rng = new Random(0);
        float[][] obs = new float[T][obsDim];
        float[][] nextObs = new float[T][obsDim];
        float[] rewards = new float[T];
        float[] dones = new float[T];
        for (int t = 0; t < T; t++) {
            for (int d = 0; d < obsDim; d++) {
                obs[t][d] = (float) rng.nextGaussian();
                nextObs[t][d] = (float) rng.nextGaussian();
            }
            rewards[t] = (float) rng.nextGaussian();
            dones[t] = (t == T - 1) ? 1.0f : 0.0f;
        }

        float[] advantages = ugtc.computeAdvantages(obs, nextObs, rewards, dones, 0.99f, true);
        System.out.printf("obs_dim: %d T: %d%n", obsDim, T);
        System.out.printf("Advantages: [%.4f, %.4f, %.4f, ...]%n",
            advantages[0], advantages[1], advantages[2]);

        // Gate check. Unicode escapes (subscript zero, Greek sigma) keep the
        // source ASCII-only so the output does not depend on file encoding.
        GateResult gate = ugtc.computeGate(obs[0], false);
        System.out.printf("Gate u(s\u2080): %.4f \u03c3_EMA: %.4f%n",
            gate.gate, ugtc.getSigmaEMA());
        System.out.println("\nSmoke test passed.");
    }
}