const AgentSac = (() => {

  const assertShape = (tensor, shape, msg = '') => {
    console.assert(
      JSON.stringify(tensor.shape) === JSON.stringify(shape),
      msg + ' shape ' + tensor.shape + ' is not ' + shape)
  }

  const VERSION = 84

  const LOG_STD_MIN = -20
  const LOG_STD_MAX = 2
  const EPSILON = 1e-8
  const NAME = {
    ACTOR: 'actor',
    Q1: 'q1',
    Q2: 'q2',
    Q1_TARGET: 'q1-target',
    Q2_TARGET: 'q2-target',
    ALPHA: 'alpha'
  }

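  /**
   * Soft Actor-Critic (SAC) agent built on the TensorFlow.js layers API.
   * It owns the stochastic actor, the twin critics (Q1, Q2), their target
   * copies and the entropy temperature α, and can checkpoint all of them
   * to IndexedDB.
   */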
  return class AgentSac {
    constructor({
      batchSize = 1,
      frameShape = [25, 25, 3],
      nFrames = 1,       // number of stacked frames per state
      nActions = 3,      // 3 - impulse, 3 - RGB color
      nTelemetry = 10,   // 3 - linear velocity, 3 - acceleration, 3 - collision point, 1 - lidar (tanh of distance)
      gamma = 0.99,      // discount factor (γ)
      tau = 5e-3,        // target smoothing coefficient (τ)
      trainable = true,  // whether the actor is trainable
      verbose = false,
      forced = false,    // force creation of fresh models (ignore any checkpoint)
      prefix = '',       // used by tests
      sighted = true,    // whether the models receive camera frames in addition to telemetry
      rewardScale = 10
    } = {}) {
      this._batchSize = batchSize
      this._frameShape = frameShape
      this._nFrames = nFrames
      this._nActions = nActions
      this._nTelemetry = nTelemetry
      this._gamma = gamma
      this._tau = tau
      this._trainable = trainable
      this._verbose = verbose
      this._inited = false
      this._prefix = (prefix === '' ? '' : prefix + '-')
      this._forced = forced
      this._sighted = sighted
      this._rewardScale = rewardScale

      this._frameStackShape = [...this._frameShape.slice(0, 2), this._frameShape[2] * this._nFrames]

      // Target entropy heuristic: −dim(A)
      this._targetEntropy = -nActions
    }

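    /**
     * Builds the networks (actor, critics, targets) and their optimizers,
     * loading them from a checkpoint when one exists. Must be called once
     * before sampling actions or training.
     */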
    async init() {
      if (this._inited) throw Error('щ(゚Д゚щ)')

      this._frameInputL = tf.input({batchShape: [null, ...this._frameStackShape]})
      this._frameInputR = tf.input({batchShape: [null, ...this._frameStackShape]})

      this._telemetryInput = tf.input({batchShape: [null, this._nTelemetry]})

      this.actor = await this._getActor(this._prefix + NAME.ACTOR, this._trainable)

      if (!this._trainable)
        return

      this.actorOptimizer = tf.train.adam()

      this._actionInput = tf.input({batchShape: [null, this._nActions]})

      this.q1 = await this._getCritic(this._prefix + NAME.Q1)
      this.q1Optimizer = tf.train.adam()

      this.q2 = await this._getCritic(this._prefix + NAME.Q2)
      this.q2Optimizer = tf.train.adam()

      this.q1Targ = await this._getCritic(this._prefix + NAME.Q1_TARGET, true)
      this.q2Targ = await this._getCritic(this._prefix + NAME.Q2_TARGET, true)

      this._logAlpha = await this._getLogAlpha(this._prefix + NAME.ALPHA)
      this.alphaOptimizer = tf.train.adam()

      // Hard-copy the online critics into the targets (τ = 1).
      this.updateTargets(1)

      this._inited = true
    }

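    /**
     * Performs one SAC update step on a batch:
     *   critics:     J_Q = 1/2 · E[(Q(s,a) − y)²]   for Q1 and Q2
     *   actor:       J_π = E[α·log π(ã|s) − min(Q1, Q2)(s, ã)]
     *   temperature: J_α = E[−α·(log π(ã|s) + targetEntropy)]
     * followed by a Polyak update of the target critics.
     * `state` / `nextState` are arrays of input tensors: telemetry first,
     * followed by the stacked-frame tensors when the agent is sighted.
     */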
    train({ state, action, reward, nextState }) {
      if (!this._trainable)
        throw new Error('Actor is not trainable')

      return tf.tidy(() => {
        assertShape(state[0], [this._batchSize, this._nTelemetry], 'telemetry')
        assertShape(state[1], [this._batchSize, ...this._frameStackShape], 'frames')
        assertShape(action, [this._batchSize, this._nActions], 'action')
        assertShape(reward, [this._batchSize, 1], 'reward')
        assertShape(nextState[0], [this._batchSize, this._nTelemetry], 'nextState telemetry')
        assertShape(nextState[1], [this._batchSize, ...this._frameStackShape], 'nextState frames')

        this._trainCritics({ state, action, reward, nextState })
        this._trainActor(state)
        this._trainAlpha(state)

        this.updateTargets()
      })
    }

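    /**
     * Updates both critics towards the soft Bellman target
     *   y = rewardScale·r + γ·( min(Q1targ, Q2targ)(s', ã') − α·log π(ã'|s') )
     * where ã' is a fresh action sampled from the current policy at s'.
     * Note there is no terminal-state mask in this target.
     */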
    _trainCritics({ state, action, reward, nextState }) {
      const getQLossFunction = (() => {
        const [nextFreshAction, logPi] = this.sampleAction(nextState, true)

        const q1TargValue = this.q1Targ.predict(
          this._sighted ? [...nextState, nextFreshAction] : [nextState[0], nextFreshAction],
          {batchSize: this._batchSize})
        const q2TargValue = this.q2Targ.predict(
          this._sighted ? [...nextState, nextFreshAction] : [nextState[0], nextFreshAction],
          {batchSize: this._batchSize})

        const qTargValue = tf.minimum(q1TargValue, q2TargValue)

        const alpha = this._getAlpha()
        const target = reward.mul(tf.scalar(this._rewardScale)).add(
          tf.scalar(this._gamma).mul(
            qTargValue.sub(alpha.mul(logPi))
          )
        )

        assertShape(nextFreshAction, [this._batchSize, this._nActions], 'nextFreshAction')
        assertShape(logPi, [this._batchSize, 1], 'logPi')
        assertShape(qTargValue, [this._batchSize, 1], 'qTargValue')
        assertShape(target, [this._batchSize, 1], 'target')

        return (q) => () => {
          const qValue = q.predict(
            this._sighted ? [...state, action] : [state[0], action],
            {batchSize: this._batchSize})

          const loss = tf.scalar(0.5).mul(tf.mean(qValue.sub(target).square()))

          assertShape(qValue, [this._batchSize, 1], 'qValue')

          return loss
        }
      })()

      for (const [q, optimizer] of [
        [this.q1, this.q1Optimizer],
        [this.q2, this.q2Optimizer]
      ]) {
        const qLossFunction = getQLossFunction(q)

        const { value, grads } = tf.variableGrads(qLossFunction, q.getWeights(true))

        optimizer.applyGradients(grads)

        if (this._verbose) console.log(q.name + ' Loss: ' + value.arraySync())
      }
    }

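    /**
     * Updates the actor by minimising
     *   J_π = E[ α·log π(ã|s) − min(Q1, Q2)(s, ã) ]
     * with ã sampled via the reparameterisation trick, so gradients flow
     * through the sampled action into the actor weights.
     */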
    _trainActor(state) {
      const actorLossFunction = () => {
        const [freshAction, logPi] = this.sampleAction(state, true)

        const q1Value = this.q1.predict(
          this._sighted ? [...state, freshAction] : [state[0], freshAction],
          {batchSize: this._batchSize})
        const q2Value = this.q2.predict(
          this._sighted ? [...state, freshAction] : [state[0], freshAction],
          {batchSize: this._batchSize})

        const criticValue = tf.minimum(q1Value, q2Value)

        const alpha = this._getAlpha()
        const loss = alpha.mul(logPi).sub(criticValue)

        assertShape(freshAction, [this._batchSize, this._nActions], 'freshAction')
        assertShape(logPi, [this._batchSize, 1], 'logPi')
        assertShape(q1Value, [this._batchSize, 1], 'q1Value')
        assertShape(criticValue, [this._batchSize, 1], 'criticValue')
        assertShape(loss, [this._batchSize, 1], 'actor loss')

        return tf.mean(loss)
      }

      const { value, grads } = tf.variableGrads(actorLossFunction, this.actor.getWeights(true))

      this.actorOptimizer.applyGradients(grads)

      if (this._verbose) console.log('Actor Loss: ' + value.arraySync())
    }

    _trainAlpha(state) {
      const alphaLossFunction = () => {
        const [, logPi] = this.sampleAction(state, true)

        const alpha = this._getAlpha()
        const loss = tf.scalar(-1).mul(
          alpha.mul(
            logPi.add(tf.scalar(this._targetEntropy))
          )
        )

        assertShape(loss, [this._batchSize, 1], 'alpha loss')

        return tf.mean(loss)
      }

      const { value, grads } = tf.variableGrads(alphaLossFunction, [this._logAlpha])

      this.alphaOptimizer.applyGradients(grads)

      if (this._verbose) console.log('Alpha Loss: ' + value.arraySync(), tf.exp(this._logAlpha).arraySync())
    }

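    /**
     * Polyak-averages the online critics into the target critics:
     *   θ_targ ← (1 − τ)·θ_targ + τ·θ
     * Called with τ = 1 at init to hard-copy the weights, and with the
     * default τ after every training step.
     */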
    updateTargets(tau = this._tau) {
      tau = tf.scalar(tau)

      const
        q1W = this.q1.getWeights(),
        q2W = this.q2.getWeights(),
        q1WTarg = this.q1Targ.getWeights(),
        q2WTarg = this.q2Targ.getWeights(),
        len = q1W.length

      const calc = (w, wTarg) => wTarg.mul(tf.scalar(1).sub(tau)).add(w.mul(tau))

      const w1 = [], w2 = []
      for (let i = 0; i < len; i++) {
        w1.push(calc(q1W[i], q1WTarg[i]))
        w2.push(calc(q2W[i], q2WTarg[i]))
      }

      this.q1Targ.setWeights(w1)
      this.q2Targ.setWeights(w2)
    }

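    /**
     * Samples an action from the squashed Gaussian policy using the
     * reparameterisation trick: u = μ + σ·ε with ε ~ N(0, 1), a = tanh(u).
     * Returns the action, and optionally its log-probability (needed for
     * the critic, actor and α updates).
     */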
    sampleAction(state, withLogProbs = false) {
      return tf.tidy(() => {
        let [mu, logStd] = this.actor.predict(this._sighted ? state : state[0], {batchSize: this._batchSize})

        logStd = tf.clipByValue(logStd, LOG_STD_MIN, LOG_STD_MAX)

        const std = tf.exp(logStd)

        const normal = tf.randomNormal(mu.shape, 0, 1.0)

        let pi = mu.add(std.mul(normal))

        let logPi = this._gaussianLikelihood(pi, mu, logStd)

        ;({ pi, logPi } = this._applySquashing(pi, mu, logPi))

        if (!withLogProbs)
          return pi

        return [pi, logPi]
      })
    }

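    /**
     * Element-wise log-density of a diagonal Gaussian:
     *   log N(x; μ, σ) = −1/2·((x − μ)/σ)² − log σ − 1/2·log(2π)
     * (Currently unused; _gaussianLikelihood below is used instead.)
     */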
    _logProb(x, mu, std) {
      const logUnnormalized = tf.scalar(-0.5).mul(
        tf.squaredDifference(x.div(std), mu.div(std))
      )
      const logNormalization = tf.scalar(0.5 * Math.log(2 * Math.PI)).add(tf.log(std))

      return logUnnormalized.sub(logNormalization)
    }

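    /**
     * Log-likelihood of x under a diagonal Gaussian with mean μ and
     * log standard deviation logStd, summed over the action dimensions:
     *   log N(x; μ, σ) = Σ_i −1/2·( ((x_i − μ_i)/σ_i)² + 2·logStd_i + log(2π) )
     * EPSILON guards against division by a zero standard deviation.
     * Returns a [batchSize, 1] tensor.
     */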
    _gaussianLikelihood(x, mu, logStd) {
      const preSum = tf.scalar(-0.5).mul(
        x.sub(mu).div(
          tf.exp(logStd).add(tf.scalar(EPSILON))
        ).square()
        .add(tf.scalar(2).mul(logStd))
        .add(tf.scalar(Math.log(2 * Math.PI)))
      )

      return tf.sum(preSum, 1, true)
    }

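    /**
     * Squashes the pre-activation sample and mean through tanh and applies
     * the change-of-variables correction to the log-probability:
     *   log π(a|s) = log μ(u|s) − Σ_i log(1 − tanh(u_i)²)
     * computed in the numerically stable form
     *   log(1 − tanh(u)²) = 2·(log 2 − u − softplus(−2u))
     */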
    _applySquashing(pi, mu, logPi) {
      const adj = tf.scalar(2).mul(
        tf.scalar(Math.log(2))
          .sub(pi)
          .sub(tf.softplus(
            tf.scalar(-2).mul(pi)
          ))
      )

      logPi = logPi.sub(tf.sum(adj, 1, true))
      mu = tf.tanh(mu)
      pi = tf.tanh(pi)

      return { pi, mu, logPi }
    }

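    /**
     * Builds (or restores from a checkpoint) the actor network. Telemetry
     * and, when sighted, the two convolutional frame encodings are
     * concatenated and passed through two dense layers; the heads output
     * μ and log σ for each action dimension.
     */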
    async _getActor(name = 'actor', trainable = true) {
      const checkpoint = await this._loadCheckpoint(name)
      if (checkpoint) return checkpoint

      let outputs = this._telemetryInput

      if (this._sighted) {
        let convOutputL = this._getConvEncoder(this._frameInputL)
        let convOutputR = this._getConvEncoder(this._frameInputR)

        outputs = tf.layers.concatenate().apply([convOutputL, convOutputR, outputs])
      }

      outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs)
      outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs)

      const mu = tf.layers.dense({units: this._nActions}).apply(outputs)
      const logStd = tf.layers.dense({units: this._nActions}).apply(outputs)

      const model = tf.model({
        inputs: this._sighted
          ? [this._telemetryInput, this._frameInputL, this._frameInputR]
          : [this._telemetryInput],
        outputs: [mu, logStd],
        name
      })
      model.trainable = trainable

      if (this._verbose) {
        console.log('==========================')
        console.log('==========================')
        console.log('Actor ' + name + ': ')

        model.summary()
      }

      return model
    }

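    /**
     * Builds (or restores from a checkpoint) a critic network. The action
     * is concatenated with telemetry (and the frame encodings when sighted)
     * and passed through two dense layers to a single Q-value output.
     */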
    async _getCritic(name = 'critic', trainable = true) {
      const checkpoint = await this._loadCheckpoint(name)
      if (checkpoint) return checkpoint

      let outputs = tf.layers.concatenate().apply([this._telemetryInput, this._actionInput])

      if (this._sighted) {
        let convOutputL = this._getConvEncoder(this._frameInputL)
        let convOutputR = this._getConvEncoder(this._frameInputR)

        outputs = tf.layers.concatenate().apply([convOutputL, convOutputR, outputs])
      }

      outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs)
      outputs = tf.layers.dense({units: 256, activation: 'relu'}).apply(outputs)

      outputs = tf.layers.dense({units: 1}).apply(outputs)

      const model = tf.model({
        inputs: this._sighted
          ? [this._telemetryInput, this._frameInputL, this._frameInputR, this._actionInput]
          : [this._telemetryInput, this._actionInput],
        outputs, name
      })

      model.trainable = trainable

      if (this._verbose) {
        console.log('==========================')
        console.log('==========================')
        console.log('CRITIC ' + name + ': ')

        model.summary()
      }

      return model
    }

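    /**
     * Small convolutional encoder for a stack of frames: two conv + max-pool
     * blocks followed by a flatten, producing a feature vector that is
     * concatenated with the other inputs.
     */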
    _getConvEncoder(inputs) {
      const kernelSize = 3
      const padding = 'valid'
      const poolSize = 3
      const strides = 1

      const kernelInitializer = 'glorotNormal'
      const biasInitializer = 'glorotNormal'

      let outputs = inputs

      outputs = tf.layers.conv2d({
        filters: 16,
        kernelSize: 5,
        strides: 2,
        padding,
        kernelInitializer,
        biasInitializer,
        activation: 'relu',
        trainable: true
      }).apply(outputs)
      outputs = tf.layers.maxPooling2d({poolSize: 2}).apply(outputs)

      outputs = tf.layers.conv2d({
        filters: 16,
        kernelSize: 3,
        strides: 1,
        padding,
        kernelInitializer,
        biasInitializer,
        activation: 'relu',
        trainable: true
      }).apply(outputs)
      outputs = tf.layers.maxPooling2d({poolSize: 2}).apply(outputs)

      outputs = tf.layers.flatten().apply(outputs)

      return outputs
    }

    _getAlpha() {
      return tf.exp(this._logAlpha)
    }

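    /**
     * Returns log α as a trainable tf.variable, restoring its value from a
     * checkpoint when one exists. The scalar is mirrored into a tiny
     * single-weight dense model (this._logAlphaPlaceholder) so it can be
     * saved and loaded through the same checkpoint path as the other models.
     */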
    async _getLogAlpha(name = 'alpha') {
      let logAlpha = 0.0

      const checkpoint = await this._loadCheckpoint(name)
      if (checkpoint) {
        logAlpha = checkpoint.getWeights()[0].arraySync()[0][0]

        if (this._verbose)
          console.log('Checkpoint alpha: ', logAlpha)

        this._logAlphaPlaceholder = checkpoint
      } else {
        const model = tf.sequential({ name })
        model.add(tf.layers.dense({ units: 1, inputShape: [1], useBias: false }))
        model.setWeights([tf.tensor([logAlpha], [1, 1])])

        this._logAlphaPlaceholder = model
      }

      return tf.variable(tf.scalar(logAlpha), true)
    }

    async checkpoint() {
      if (!this._trainable) throw new Error('(╭ರ_ ⊙ )')

      // Sync the current log α into its placeholder model before saving.
      this._logAlphaPlaceholder.setWeights([tf.tensor([this._logAlpha.arraySync()], [1, 1])])

      await Promise.all([
        this._saveCheckpoint(this.actor),
        this._saveCheckpoint(this.q1),
        this._saveCheckpoint(this.q2),
        this._saveCheckpoint(this.q1Targ),
        this._saveCheckpoint(this.q2Targ),
        this._saveCheckpoint(this._logAlphaPlaceholder)
      ])

      if (this._verbose)
        console.log('Checkpoint successfully saved')
    }

    async _saveCheckpoint(model) {
      const key = this._getChKey(model.name)
      const saveResults = await model.save(key)

      if (this._verbose)
        console.log('Checkpoint saveResults', model.name, saveResults)
    }

    async _loadCheckpoint(name) {
      if (this._forced) {
        console.log('Forced to not load from the checkpoint ' + name)
        return
      }

      const key = this._getChKey(name)
      const modelsInfo = await tf.io.listModels()

      if (key in modelsInfo) {
        const model = await tf.loadLayersModel(key)

        if (this._verbose)
          console.log('Loaded checkpoint for ' + name)

        return model
      }

      if (this._verbose)
        console.log('Checkpoint not found for ' + name)
    }

    _getChKey(name) {
      return 'indexeddb://' + name + '-' + VERSION
    }
  }
})()

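// Minimal usage sketch (illustration only, not part of the original tests):
// `tf` is assumed to be the global TensorFlow.js namespace, and the tensors
// are assumed to match the shapes asserted in train().
//
//   const agent = new AgentSac({ batchSize: 1 })
//   await agent.init()
//   const state = [
//     tf.randomNormal([1, 10]),        // telemetry
//     tf.randomNormal([1, 25, 25, 3]), // left frame stack
//     tf.randomNormal([1, 25, 25, 3])  // right frame stack
//   ]
//   const action = agent.sampleAction(state)                 // [1, nActions]
//   const reward = tf.randomNormal([1, 1])
//   agent.train({ state, action, reward, nextState: state }) // one SAC step
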
;(async () => {
  // Smoke tests. The early return below disables them; remove it to run.
  return

  ;(() => {
    const agent = new AgentSac()

    const
      mu = tf.tensor([0], [1, 1]),
      logStd = tf.tensor([0], [1, 1]),
      std = tf.exp(logStd),
      normal = tf.tensor([0], [1, 1]),
      pi = mu.add(std.mul(normal))

    const log = agent._gaussianLikelihood(pi, mu, logStd)

    console.assert(log.arraySync()[0][0].toFixed(5) === '-0.91894',
      'test Gaussian Likelihood for μ=0, σ=1, x=0')
  })()

  ;(() => {
    const agent = new AgentSac()

    const
      mu = tf.tensor([1], [1, 1]),
      logStd = tf.tensor([1], [1, 1]),
      std = tf.exp(logStd),
      normal = tf.tensor([0], [1, 1]),
      pi = mu.add(std.mul(normal))

    const log = agent._gaussianLikelihood(pi, mu, logStd)

    console.assert(log.arraySync()[0][0].toFixed(5) === '-1.91894',
      'test Gaussian Likelihood for μ=1, σ=e, x=1')
  })()

  ;(() => {
    const agent = new AgentSac()

    const
      mu = tf.tensor([1], [1, 1]),
      logStd = tf.tensor([1], [1, 1]),
      std = tf.exp(logStd),
      normal = tf.tensor([0.1], [1, 1]),
      pi = mu.add(std.mul(normal))

    const logPi = agent._gaussianLikelihood(pi, mu, logStd)
    const { pi: piSquashed, logPi: logPiSquashed } = agent._applySquashing(pi, mu, logPi)

    // Reference value: the exact (less numerically stable) tanh correction.
    const logProbBounded = logPi.sub(
      tf.log(
        tf.scalar(1)
          .sub(tf.tanh(pi).pow(tf.scalar(2)))
      )
    ).sum(1, true)

    console.assert(logPi.arraySync()[0][0].toFixed(5) === '-1.92394',
      'test Gaussian Likelihood for μ=1, σ=e, x≈1.27182818')

    console.assert(logPiSquashed.arraySync()[0][0].toFixed(5) === logProbBounded.arraySync()[0][0].toFixed(5),
      'test logPiSquashed for μ=1, σ=e, x≈1.27182818')

    console.assert(piSquashed.arraySync()[0][0].toFixed(5) === tf.tanh(pi).arraySync()[0][0].toFixed(5),
      'test piSquashed for μ=1, σ=e, x≈1.27182818')
  })()

  await (async () => {
    const state = tf.tensor([
      0.5, 0.3, -0.9,
      0, -0.8, 1,
      -0.3, 0.04, 0.02,
      0.9
    ], [1, 10])

    const action = tf.tensor([
      0.1, -1, -0.4,
      1, -0.8, -0.8, -0.2,
      0.04, 0.02, 0.001
    ], [1, 10])

    const fresh = new AgentSac({ prefix: 'test', forced: true })
    await fresh.init()
    await fresh.checkpoint()

    const saved = new AgentSac({ prefix: 'test' })
    await saved.init()

    let frPred, saPred

    frPred = fresh.actor.predict(state, {batchSize: 1})
    saPred = saved.actor.predict(state, {batchSize: 1})
    console.assert(
      frPred[0].arraySync().length > 0 &&
      frPred[1].arraySync().length > 0 &&
      frPred[0].arraySync().join(';') === saPred[0].arraySync().join(';') &&
      frPred[1].arraySync().join(';') === saPred[1].arraySync().join(';'),
      'Models loaded from the checkpoint should be the same')

    frPred = fresh.q1.predict([state, action], {batchSize: 1})
    saPred = fresh.q1Targ.predict([state, action], {batchSize: 1})
    console.assert(
      frPred.arraySync()[0][0] !== undefined &&
      frPred.arraySync()[0][0] === saPred.arraySync()[0][0],
      'Q1 and Q1-target should be the same')

    frPred = fresh.q2.predict([state, action], {batchSize: 1})
    saPred = saved.q2.predict([state, action], {batchSize: 1})
    console.assert(
      frPred.arraySync()[0][0] !== undefined &&
      frPred.arraySync()[0][0] === saPred.arraySync()[0][0],
      'Q2 and Q2 restored should be the same')

    console.assert(
      fresh._logAlpha.arraySync() !== undefined &&
      fresh._logAlpha.arraySync() === saved._logAlpha.arraySync(),
      'Alpha and alpha restored should be the same')
  })()
})()