// Source: ugtc/implementations/cpp/ugtc.hpp (author: Ekrem-the-second)
// Commit d92d8cf (verified): "Initial release: UGTC - Uncertainty-Gated Temporal Credit"
// Listed size: 13.2 kB
/**
* UGTC: Uncertainty-Gated Temporal Credit — C++ Header-Only Reference Implementation
* ====================================================================================
*
* A minimal reference implementation of the UGTC module. Its only external
* dependency is Eigen3 (used for matrix operations); it needs no RL framework.
*
* Requirements:
* - C++17 or later
* - Eigen3 (https://eigen.tuxfamily.org/)
*
* Usage:
* #include "ugtc.hpp"
*
* UGTC::Config cfg;
* UGTC::Module ugtc(obs_dim, cfg);
* auto advantages = ugtc.computeAdvantages(obs, next_obs, rewards, dones, gamma);
*
* Paper: https://doi.org/10.5281/zenodo.19715116
*/
#pragma once
#include <vector>
#include <cmath>
#include <numeric>
#include <random>
#include <cassert>
#include <algorithm>
#include <Eigen/Dense>
namespace UGTC {
/// Dense, dynamically sized single-precision matrix (the same type as Eigen::MatrixXf).
using Matrix = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
/// Dense, dynamically sized single-precision column vector (the same type as Eigen::VectorXf).
using Vector = Eigen::Matrix<float, Eigen::Dynamic, 1>;
// ──────────────────────────────────────────────────────────────────────────────
// Configuration
// ──────────────────────────────────────────────────────────────────────────────
/// Hyper-parameters of the UGTC module; every field carries its default inline.
/// Member order is part of the aggregate-initialization interface — keep it stable.
struct Config {
    // ── Network shape ──
    int hidden_dim{64};        ///< Width of each hidden layer in every critic
    int M{3};                  ///< Number of members in the slow-critic ensemble

    // ── GAE horizons ──
    float lambda_fast{0.80f};  ///< GAE lambda used with the fast critic's values
    float lambda_slow{0.99f};  ///< GAE lambda used with the slow ensemble's mean values

    // ── Gate ──
    float beta{5.0f};          ///< Gate temperature (sharpness of the sigmoid)
    float ema_momentum{0.99f}; ///< EMA momentum for the running uncertainty normalizer
    float eps{1e-8f};          ///< Added to denominators for numerical stability
};
// ──────────────────────────────────────────────────────────────────────────────
// Activation functions
// ──────────────────────────────────────────────────────────────────────────────
/// Logistic function 1 / (1 + e^(-x)) in single precision.
/// Saturates to exactly 0 or 1 for large |x| (exp overflows to +inf / underflows to 0).
inline float sigmoid(float x) {
    const float decay = std::exp(-x);
    return 1.0f / (1.0f + decay);
}
/// Scalar hyperbolic tangent; a named wrapper around the float overload of std::tanh.
inline auto tanh_activation(float x) -> float { return std::tanh(x); }
/// Element-wise std::tanh over a vector; returns a new vector of the same length.
inline Vector tanh_vec(const Vector& x) {
    Vector out(x.size());
    for (Eigen::Index i = 0; i < x.size(); ++i) {
        out(i) = std::tanh(x(i));
    }
    return out;
}
// ──────────────────────────────────────────────────────────────────────────────
// Linear layer (weight matrix + bias vector)
// ──────────────────────────────────────────────────────────────────────────────
struct Linear {
Matrix W; ///< (out_dim, in_dim)
Vector b; ///< (out_dim,)
Linear() = default;
Linear(int in_dim, int out_dim, std::mt19937& rng) {
W = Matrix::Random(out_dim, in_dim);
b = Vector::Zero(out_dim);
// Orthogonal-ish initialization via scaled random
float scale = std::sqrt(2.0f / in_dim);
W *= scale;
}
Vector forward(const Vector& x) const {
return W * x + b;
}
};
// ──────────────────────────────────────────────────────────────────────────────
// Value network: obs → hidden → hidden → scalar
// Architecture: Linear → Tanh → Linear → Tanh → Linear
// ──────────────────────────────────────────────────────────────────────────────
struct ValueNetwork {
Linear fc1, fc2, fc3;
ValueNetwork() = default;
ValueNetwork(int obs_dim, int hidden_dim, std::mt19937& rng)
: fc1(obs_dim, hidden_dim, rng)
, fc2(hidden_dim, hidden_dim, rng)
, fc3(hidden_dim, 1, rng)
{}
float forward(const Vector& obs) const {
Vector h1 = tanh_vec(fc1.forward(obs));
Vector h2 = tanh_vec(fc2.forward(h1));
return fc3.forward(h2)(0);
}
};
// ──────────────────────────────────────────────────────────────────────────────
// Ensemble value network: M independent ValueNetworks
// ──────────────────────────────────────────────────────────────────────────────
struct EnsembleValueNetwork {
std::vector<ValueNetwork> members;
int M;
EnsembleValueNetwork() = default;
EnsembleValueNetwork(int obs_dim, int hidden_dim, int M, std::mt19937& rng)
: M(M)
{
members.reserve(M);
for (int i = 0; i < M; ++i) {
members.emplace_back(obs_dim, hidden_dim, rng);
}
}
/// Returns (mean, std) of ensemble predictions for a single observation.
std::pair<float, float> forward(const Vector& obs) const {
std::vector<float> vals;
vals.reserve(M);
for (auto& m : members) vals.push_back(m.forward(obs));
float mean = std::accumulate(vals.begin(), vals.end(), 0.0f) / M;
float var = 0.0f;
for (float v : vals) var += (v - mean) * (v - mean);
var /= (M > 1 ? M - 1 : 1);
return { mean, std::sqrt(var) };
}
};
// ──────────────────────────────────────────────────────────────────────────────
// Gate statistics output
// ──────────────────────────────────────────────────────────────────────────────
/// Per-observation output of the UGTC gate computation.
/// Members are value-initialized so a default-constructed result is all zeros
/// rather than indeterminate.
struct GateResult {
    float gate = 0.0f;   ///< u(s) = sigmoid(-beta * (sigma / sigma_EMA - 1)), in [0, 1]
    float v_fast = 0.0f; ///< Fast critic value
    float v_slow = 0.0f; ///< Slow ensemble mean value
    float sigma = 0.0f;  ///< Ensemble disagreement (standard deviation across members)
};
// ──────────────────────────────────────────────────────────────────────────────
// UGTC Module
// ──────────────────────────────────────────────────────────────────────────────
class Module {
public:
Module(int obs_dim, const Config& cfg = Config{})
: cfg_(cfg)
, sigma_ema_(1.0f)
{
std::mt19937 rng(42);
fast_critic_ = ValueNetwork(obs_dim, cfg.hidden_dim, rng);
slow_ensemble_ = EnsembleValueNetwork(obs_dim, cfg.hidden_dim, cfg.M, rng);
}
// ── Gate computation ──────────────────────────────────────────────────────
/**
* Compute the uncertainty gate u(s) for a single observation.
*
* Steps:
* 1. Evaluate fast critic: v_fast = V_fast(s)
* 2. Evaluate slow ensemble: (v̄_slow, σ) = ensemble(s)
* 3. EMA-normalize: σ̂ = σ / σ_EMA
* 4. Sigmoid gate: u(s) = sigmoid(-β · (σ̂ - 1))
*
* @param obs Observation vector (obs_dim,)
* @param train Whether to update EMA (true during training)
* @return GateResult with gate, v_fast, v_slow, sigma
*/
GateResult computeGate(const Vector& obs, bool train = false) {
float v_fast = fast_critic_.forward(obs);
auto [v_slow, sigma] = slow_ensemble_.forward(obs);
if (train) {
sigma_ema_ = cfg_.ema_momentum * sigma_ema_
+ (1.0f - cfg_.ema_momentum) * sigma;
}
float normalized_sigma = sigma / (sigma_ema_ + cfg_.eps);
float gate = sigmoid(-cfg_.beta * (normalized_sigma - 1.0f));
return { gate, v_fast, v_slow, sigma };
}
// ── Value estimation ──────────────────────────────────────────────────────
/**
* Blended value estimate V^UGTC(s) = u(s)·V̄_slow(s) + (1-u(s))·V_fast(s)
*/
float getValueUGTC(const Vector& obs, bool train = false) {
auto r = computeGate(obs, train);
return r.gate * r.v_slow + (1.0f - r.gate) * r.v_fast;
}
// ── GAE computation ───────────────────────────────────────────────────────
/**
* Standard Generalized Advantage Estimation.
*
* δₜ = rₜ + γ·V(sₜ₊₁)·(1-dₜ) - V(sₜ)
* Aₜ = δₜ + γλ·(1-dₜ)·Aₜ₊₁
*
* @param rewards (T,) reward sequence
* @param values (T,) current-state values
* @param next_vals (T,) next-state values
* @param dones (T,) episode termination flags
* @param gamma discount factor
* @param lam GAE lambda
* @return (T,) advantage estimates
*/
static std::vector<float> computeGAE(
const std::vector<float>& rewards,
const std::vector<float>& values,
const std::vector<float>& next_vals,
const std::vector<float>& dones,
float gamma,
float lam
) {
int T = static_cast<int>(rewards.size());
std::vector<float> advantages(T, 0.0f);
float gae = 0.0f;
for (int t = T - 1; t >= 0; --t) {
float delta = rewards[t] + gamma * next_vals[t] * (1.0f - dones[t]) - values[t];
gae = delta + gamma * lam * (1.0f - dones[t]) * gae;
advantages[t] = gae;
}
return advantages;
}
// ── UGTC advantage ────────────────────────────────────────────────────────
/**
* Compute UGTC blended advantages for a trajectory.
*
* A^UGTC_t = u(sₜ)·A^slow_t + (1-u(sₜ))·A^fast_t
*
* @param obs_seq Sequence of observations (T × obs_dim)
* @param next_obs_seq Sequence of next observations (T × obs_dim)
* @param rewards (T,) rewards
* @param dones (T,) done flags
* @param gamma Discount factor
* @param train Whether to update EMA
* @return (T,) UGTC blended advantages
*/
std::vector<float> computeAdvantages(
const std::vector<Vector>& obs_seq,
const std::vector<Vector>& next_obs_seq,
const std::vector<float>& rewards,
const std::vector<float>& dones,
float gamma = 0.99f,
bool train = false
) {
int T = static_cast<int>(obs_seq.size());
assert(T == static_cast<int>(rewards.size()));
std::vector<float> gates(T), v_fast_arr(T), v_slow_arr(T);
std::vector<float> v_fast_next(T), v_slow_next(T);
for (int t = 0; t < T; ++t) {
auto r = computeGate(obs_seq[t], train);
auto r_next = computeGate(next_obs_seq[t], false);
gates[t] = r.gate;
v_fast_arr[t] = r.v_fast;
v_slow_arr[t] = r.v_slow;
v_fast_next[t] = r_next.v_fast;
v_slow_next[t] = r_next.v_slow;
}
auto adv_fast = computeGAE(rewards, v_fast_arr, v_fast_next, dones, gamma, cfg_.lambda_fast);
auto adv_slow = computeGAE(rewards, v_slow_arr, v_slow_next, dones, gamma, cfg_.lambda_slow);
std::vector<float> advantages(T);
for (int t = 0; t < T; ++t) {
advantages[t] = gates[t] * adv_slow[t] + (1.0f - gates[t]) * adv_fast[t];
}
return advantages;
}
// ── Accessors ─────────────────────────────────────────────────────────────
float getSigmaEMA() const { return sigma_ema_; }
const Config& getConfig() const { return cfg_; }
private:
Config cfg_;
ValueNetwork fast_critic_;
EnsembleValueNetwork slow_ensemble_;
float sigma_ema_;
};
} // namespace UGTC