// File size: 13,180 Bytes (d92d8cf)
/**
* UGTC: Uncertainty-Gated Temporal Credit - C++ Header-Only Reference Implementation
* ====================================================================================
*
* A minimal reference implementation of the UGTC module.
* The only external dependency is Eigen3 (matrix operations); no RL framework is required.
*
* Requirements:
* - C++17 or later
* - Eigen3 (https://eigen.tuxfamily.org/)
*
* Usage:
* #include "ugtc.hpp"
*
* UGTC::Config cfg;
* UGTC::Module ugtc(obs_dim, cfg);
* auto advantages = ugtc.computeAdvantages(obs, next_obs, rewards, dones, gamma);
*
* Paper: https://doi.org/10.5281/zenodo.19715116
*/
#pragma once
#include <vector>
#include <cmath>
#include <numeric>
#include <random>
#include <cassert>
#include <algorithm>
#include <Eigen/Dense>
namespace UGTC {
// Single-precision, dynamically sized Eigen types used throughout this header.
using Matrix = Eigen::MatrixXf;  // dynamic rows x cols, float coefficients
using Vector = Eigen::VectorXf;  // dynamic-length column vector, float coefficients
// -----------------------------------------------------------------------------
// Configuration
// -----------------------------------------------------------------------------
/// Hyperparameters of the UGTC module. Every field carries a default, so a
/// default-constructed Config is directly usable.
struct Config {
    /// Width of the hidden layers of each value network.
    int hidden_dim{64};
    /// Number of members in the slow-critic ensemble.
    int M{3};
    /// GAE lambda used for the fast critic's advantages.
    float lambda_fast{0.80f};
    /// GAE lambda used for the slow ensemble's advantages.
    float lambda_slow{0.99f};
    /// Gate temperature: scales the normalized uncertainty inside the sigmoid.
    float beta{5.0f};
    /// Momentum of the exponential moving average that normalizes uncertainty.
    float ema_momentum{0.99f};
    /// Small constant added to the normalization denominator for stability.
    float eps{1e-8f};
};
// -----------------------------------------------------------------------------
// Activation functions
// -----------------------------------------------------------------------------
/// Logistic sigmoid of a scalar: 1 / (1 + e^(-x)).
inline float sigmoid(float x) {
    const float decay = std::exp(-x);
    const float denom = 1.0f + decay;
    return 1.0f / denom;
}
/// Scalar hyperbolic-tangent activation; forwards to std::tanh.
inline float tanh_activation(float x) {
    const float activated = std::tanh(x);
    return activated;
}
/// Returns a new vector whose coefficients are std::tanh of the coefficients of x.
inline Vector tanh_vec(const Vector& x) {
    const auto apply_tanh = [](float coeff) -> float { return std::tanh(coeff); };
    return x.unaryExpr(apply_tanh);
}
// -----------------------------------------------------------------------------
// Linear layer (weight matrix + bias vector)
// -----------------------------------------------------------------------------
struct Linear {
Matrix W; ///< (out_dim, in_dim)
Vector b; ///< (out_dim,)
Linear() = default;
Linear(int in_dim, int out_dim, std::mt19937& rng) {
W = Matrix::Random(out_dim, in_dim);
b = Vector::Zero(out_dim);
// Orthogonal-ish initialization via scaled random
float scale = std::sqrt(2.0f / in_dim);
W *= scale;
}
Vector forward(const Vector& x) const {
return W * x + b;
}
};
// -----------------------------------------------------------------------------
// Value network: obs -> hidden -> hidden -> scalar
// Architecture: Linear -> Tanh -> Linear -> Tanh -> Linear
// -----------------------------------------------------------------------------
struct ValueNetwork {
Linear fc1, fc2, fc3;
ValueNetwork() = default;
ValueNetwork(int obs_dim, int hidden_dim, std::mt19937& rng)
: fc1(obs_dim, hidden_dim, rng)
, fc2(hidden_dim, hidden_dim, rng)
, fc3(hidden_dim, 1, rng)
{}
float forward(const Vector& obs) const {
Vector h1 = tanh_vec(fc1.forward(obs));
Vector h2 = tanh_vec(fc2.forward(h1));
return fc3.forward(h2)(0);
}
};
// -----------------------------------------------------------------------------
// Ensemble value network: M independent ValueNetworks
// -----------------------------------------------------------------------------
struct EnsembleValueNetwork {
std::vector<ValueNetwork> members;
int M;
EnsembleValueNetwork() = default;
EnsembleValueNetwork(int obs_dim, int hidden_dim, int M, std::mt19937& rng)
: M(M)
{
members.reserve(M);
for (int i = 0; i < M; ++i) {
members.emplace_back(obs_dim, hidden_dim, rng);
}
}
/// Returns (mean, std) of ensemble predictions for a single observation.
std::pair<float, float> forward(const Vector& obs) const {
std::vector<float> vals;
vals.reserve(M);
for (auto& m : members) vals.push_back(m.forward(obs));
float mean = std::accumulate(vals.begin(), vals.end(), 0.0f) / M;
float var = 0.0f;
for (float v : vals) var += (v - mean) * (v - mean);
var /= (M > 1 ? M - 1 : 1);
return { mean, std::sqrt(var) };
}
};
// -----------------------------------------------------------------------------
// Gate statistics output
// -----------------------------------------------------------------------------
/// Per-observation output of the uncertainty gate. All fields default to 0 so
/// a default-initialized GateResult never holds indeterminate values.
struct GateResult {
    float gate = 0.0f;    ///< u(s) in [0, 1]
    float v_fast = 0.0f;  ///< Fast critic value
    float v_slow = 0.0f;  ///< Slow ensemble mean value
    float sigma = 0.0f;   ///< Ensemble disagreement (std)
};
// -----------------------------------------------------------------------------
// UGTC Module
// -----------------------------------------------------------------------------
class Module {
public:
Module(int obs_dim, const Config& cfg = Config{})
: cfg_(cfg)
, sigma_ema_(1.0f)
{
std::mt19937 rng(42);
fast_critic_ = ValueNetwork(obs_dim, cfg.hidden_dim, rng);
slow_ensemble_ = EnsembleValueNetwork(obs_dim, cfg.hidden_dim, cfg.M, rng);
}
// ββ Gate computation ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
/**
* Compute the uncertainty gate u(s) for a single observation.
*
* Steps:
* 1. Evaluate fast critic: v_fast = V_fast(s)
* 2. Evaluate slow ensemble: (vΜ_slow, Ο) = ensemble(s)
* 3. EMA-normalize: ΟΜ = Ο / Ο_EMA
* 4. Sigmoid gate: u(s) = sigmoid(-Ξ² Β· (ΟΜ - 1))
*
* @param obs Observation vector (obs_dim,)
* @param train Whether to update EMA (true during training)
* @return GateResult with gate, v_fast, v_slow, sigma
*/
GateResult computeGate(const Vector& obs, bool train = false) {
float v_fast = fast_critic_.forward(obs);
auto [v_slow, sigma] = slow_ensemble_.forward(obs);
if (train) {
sigma_ema_ = cfg_.ema_momentum * sigma_ema_
+ (1.0f - cfg_.ema_momentum) * sigma;
}
float normalized_sigma = sigma / (sigma_ema_ + cfg_.eps);
float gate = sigmoid(-cfg_.beta * (normalized_sigma - 1.0f));
return { gate, v_fast, v_slow, sigma };
}
// ββ Value estimation ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
/**
* Blended value estimate V^UGTC(s) = u(s)Β·VΜ_slow(s) + (1-u(s))Β·V_fast(s)
*/
float getValueUGTC(const Vector& obs, bool train = false) {
auto r = computeGate(obs, train);
return r.gate * r.v_slow + (1.0f - r.gate) * r.v_fast;
}
// ββ GAE computation βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
/**
* Standard Generalized Advantage Estimation.
*
* Ξ΄β = rβ + Ξ³Β·V(sβββ)Β·(1-dβ) - V(sβ)
* Aβ = Ξ΄β + γλ·(1-dβ)Β·Aβββ
*
* @param rewards (T,) reward sequence
* @param values (T,) current-state values
* @param next_vals (T,) next-state values
* @param dones (T,) episode termination flags
* @param gamma discount factor
* @param lam GAE lambda
* @return (T,) advantage estimates
*/
static std::vector<float> computeGAE(
const std::vector<float>& rewards,
const std::vector<float>& values,
const std::vector<float>& next_vals,
const std::vector<float>& dones,
float gamma,
float lam
) {
int T = static_cast<int>(rewards.size());
std::vector<float> advantages(T, 0.0f);
float gae = 0.0f;
for (int t = T - 1; t >= 0; --t) {
float delta = rewards[t] + gamma * next_vals[t] * (1.0f - dones[t]) - values[t];
gae = delta + gamma * lam * (1.0f - dones[t]) * gae;
advantages[t] = gae;
}
return advantages;
}
// ββ UGTC advantage ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
/**
* Compute UGTC blended advantages for a trajectory.
*
* A^UGTC_t = u(sβ)Β·A^slow_t + (1-u(sβ))Β·A^fast_t
*
* @param obs_seq Sequence of observations (T Γ obs_dim)
* @param next_obs_seq Sequence of next observations (T Γ obs_dim)
* @param rewards (T,) rewards
* @param dones (T,) done flags
* @param gamma Discount factor
* @param train Whether to update EMA
* @return (T,) UGTC blended advantages
*/
std::vector<float> computeAdvantages(
const std::vector<Vector>& obs_seq,
const std::vector<Vector>& next_obs_seq,
const std::vector<float>& rewards,
const std::vector<float>& dones,
float gamma = 0.99f,
bool train = false
) {
int T = static_cast<int>(obs_seq.size());
assert(T == static_cast<int>(rewards.size()));
std::vector<float> gates(T), v_fast_arr(T), v_slow_arr(T);
std::vector<float> v_fast_next(T), v_slow_next(T);
for (int t = 0; t < T; ++t) {
auto r = computeGate(obs_seq[t], train);
auto r_next = computeGate(next_obs_seq[t], false);
gates[t] = r.gate;
v_fast_arr[t] = r.v_fast;
v_slow_arr[t] = r.v_slow;
v_fast_next[t] = r_next.v_fast;
v_slow_next[t] = r_next.v_slow;
}
auto adv_fast = computeGAE(rewards, v_fast_arr, v_fast_next, dones, gamma, cfg_.lambda_fast);
auto adv_slow = computeGAE(rewards, v_slow_arr, v_slow_next, dones, gamma, cfg_.lambda_slow);
std::vector<float> advantages(T);
for (int t = 0; t < T; ++t) {
advantages[t] = gates[t] * adv_slow[t] + (1.0f - gates[t]) * adv_fast[t];
}
return advantages;
}
// ββ Accessors βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
float getSigmaEMA() const { return sigma_ema_; }
const Config& getConfig() const { return cfg_; }
private:
Config cfg_;
ValueNetwork fast_critic_;
EnsembleValueNetwork slow_ensemble_;
float sigma_ema_;
};
} // namespace UGTC