Aksh Parekh Claude Opus 4.6 commited on
Commit
0977416
Β·
1 Parent(s): 68d6c60

feat: dual-axis reward chart with cumulative net, dynamic mode label, and mean

Browse files

- Track cumulative (net) reward across all episodes
- Reward chart shows per-episode (blue, left axis) + net cumulative (orange, right axis)
- Chart title dynamically shows CAPPED/UNCAPPED mode, mean, and net total
- X-axis always spans t=0 to current t (scales to fit all data)
- Remove episode_history cap so full history is preserved

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. server/app.py +158 -33
  2. server/requirements.txt +3 -0
server/app.py CHANGED
@@ -104,6 +104,7 @@ class TrainingState:
104
  total_responses: int = 0
105
  # History for charts
106
  reward_history: List[float] = field(default_factory=list) # per-episode cumulative
 
107
  accuracy_history: List[float]= field(default_factory=list) # per-episode correct%
108
  episode_history: List[Dict] = field(default_factory=list)
109
  # PPO
@@ -305,13 +306,10 @@ def _training_loop() -> None:
305
  _state.mean_reward_100 = round(float(np.mean(ep_rewards)), 2)
306
  _state.mean_ep_len = round(float(np.mean(ep_lengths)), 1)
307
  _state.reward_history.append(round(ep_reward, 2))
 
 
308
  _state.accuracy_history.append(round(acc, 1))
309
- if len(_state.reward_history) > 500:
310
- _state.reward_history = _state.reward_history[-500:]
311
- _state.accuracy_history = _state.accuracy_history[-500:]
312
  _state.episode_history.append(ep_rec)
313
- if len(_state.episode_history) > 200:
314
- _state.episode_history = _state.episode_history[-200:]
315
 
316
  _push_sse({"type": "episode", "data": ep_rec})
317
 
@@ -419,8 +417,9 @@ def get_state():
419
  "pg_loss": s.last_pg_loss,
420
  "vf_loss": s.last_vf_loss,
421
  "entropy": s.last_entropy,
422
- "reward_history": s.reward_history[-300:],
423
- "accuracy_history": s.accuracy_history[-300:],
 
424
  "episode_history": s.episode_history[-50:],
425
  "incident_feed": s.incident_feed[-30:],
426
  "incident_counts": s.incident_counts,
@@ -593,7 +592,7 @@ td { padding:4px 8px; border-bottom:1px solid #13131a; }
593
  <!-- RIGHT -->
594
  <div class="right">
595
  <div class="chart-wrap" style="flex:0 0 45%">
596
- <div class="chart-title">EPISODE REWARD HISTORY</div>
597
  <canvas class="chart" id="rwChart"></canvas>
598
  </div>
599
  <div class="chart-wrap" style="flex:0 0 28%">
@@ -619,7 +618,7 @@ td { padding:4px 8px; border-bottom:1px solid #13131a; }
619
  <script>
620
  // ── State ──────────────────────────────────────────────────────────────────
621
  let S = {
622
- cars:[], reward_history:[], accuracy_history:[], episode_history:[],
623
  incident_feed:[], incident_counts:{}, stage:1, reward_mode:'capped',
624
  response_accuracy:0, total_steps:0, n_episodes:0, n_updates:0,
625
  ego_x:0, goal_x:180, episode_reward:0, episode_steps:0,
@@ -687,6 +686,8 @@ function drawRoad() {
687
  }
688
 
689
  // ── Chart drawing ────────────────────────────────────────────────────────
 
 
690
  function drawLineChart(canvasId, data, color, label, yMin, yMax, showZero) {
691
  const canvas = document.getElementById(canvasId);
692
  const w = canvas.offsetWidth||400, h = canvas.offsetHeight||160;
@@ -695,15 +696,19 @@ function drawLineChart(canvasId, data, color, label, yMin, yMax, showZero) {
695
  ctx.clearRect(0,0,w,h);
696
  if(!data||data.length<2){
697
  ctx.fillStyle='#445'; ctx.font='11px Courier New'; ctx.textAlign='center';
698
- ctx.fillText('Waiting for data...', w/2, h/2);
699
  return;
700
  }
701
- const pad={t:8,r:8,b:22,l:48};
702
  const pw=w-pad.l-pad.r, ph=h-pad.t-pad.b;
703
  const mn = yMin!==undefined?yMin:Math.min(...data);
704
  const mx = yMax!==undefined?yMax:Math.max(...data);
705
  const rng = mx-mn||1;
706
- // Grid
 
 
 
 
707
  ctx.strokeStyle='#1a1a28'; ctx.lineWidth=1;
708
  for(let i=0;i<=4;i++){
709
  const y=pad.t+ph*(i/4);
@@ -711,33 +716,150 @@ function drawLineChart(canvasId, data, color, label, yMin, yMax, showZero) {
711
  ctx.fillStyle='#445'; ctx.font='8px Courier New'; ctx.textAlign='right';
712
  ctx.fillText((mx-rng*(i/4)).toFixed(1), pad.l-3, y+3);
713
  }
 
714
  // Zero line
715
  if(showZero && mn<0 && mx>0){
716
- const zy=pad.t+(mx/rng)*ph;
717
- ctx.strokeStyle='#3a3a50'; ctx.setLineDash([4,4]);
718
  ctx.beginPath();ctx.moveTo(pad.l,zy);ctx.lineTo(pad.l+pw,zy);ctx.stroke();
719
  ctx.setLineDash([]);
720
  }
721
- // MA-10
722
- const ma=data.map((_,i)=>{
723
- const sl=data.slice(Math.max(0,i-9),i+1);
 
 
724
  return sl.reduce((a,b)=>a+b,0)/sl.length;
725
  });
726
- // Raw
727
- ctx.strokeStyle=color+'44'; ctx.lineWidth=1; ctx.beginPath();
728
- data.forEach((v,i)=>{
729
- const x=pad.l+i*(pw/(data.length-1)), y=pad.t+(mx-v)/rng*ph;
730
- i?ctx.lineTo(x,y):ctx.moveTo(x,y);
731
- }); ctx.stroke();
732
- // Smoothed
 
 
 
 
 
 
 
 
 
733
  ctx.strokeStyle=color; ctx.lineWidth=2; ctx.beginPath();
734
- ma.forEach((v,i)=>{
735
- const x=pad.l+i*(pw/(ma.length-1)), y=pad.t+(mx-v)/rng*ph;
736
- i?ctx.lineTo(x,y):ctx.moveTo(x,y);
737
- }); ctx.stroke();
738
- // X label
739
- ctx.fillStyle='#445'; ctx.font='8px Courier New'; ctx.textAlign='center';
740
- ctx.fillText(label+' ('+data.length+')', pad.l+pw/2, h-4);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741
  }
742
 
743
  // ── Incident feed ────────────────────────────────────────────────────────
@@ -833,7 +955,7 @@ function updateUI() {
833
  // ── Render all ────────────────────────────────────────────────────────────
834
  function renderAll() {
835
  drawRoad();
836
- drawLineChart('rwChart', S.reward_history, '#7eb8ff', 'Episodes', undefined, undefined, true);
837
  drawLineChart('accChart', S.accuracy_history, '#4caf50', 'Episodes', 0, 100, false);
838
  renderFeed(S.incident_feed||[]);
839
  renderEpTable(S.episode_history||[]);
@@ -861,10 +983,13 @@ evtSrc.onmessage = (e) => {
861
  S.episode_history.push(msg.data);
862
  if(!S.reward_history) S.reward_history=[];
863
  S.reward_history.push(msg.data.reward);
 
 
 
864
  if(!S.accuracy_history) S.accuracy_history=[];
865
  S.accuracy_history.push(msg.data.accuracy||0);
866
  renderEpTable(S.episode_history);
867
- drawLineChart('rwChart', S.reward_history, '#7eb8ff', 'Episodes', undefined, undefined, true);
868
  drawLineChart('accChart', S.accuracy_history, '#4caf50', 'Episodes', 0, 100, false);
869
  } else if(msg.type==='tick'){
870
  Object.assign(S, msg.data);
 
104
  total_responses: int = 0
105
  # History for charts
106
  reward_history: List[float] = field(default_factory=list) # per-episode cumulative
107
+ cumulative_reward: List[float] = field(default_factory=list) # running net total
108
  accuracy_history: List[float]= field(default_factory=list) # per-episode correct%
109
  episode_history: List[Dict] = field(default_factory=list)
110
  # PPO
 
306
  _state.mean_reward_100 = round(float(np.mean(ep_rewards)), 2)
307
  _state.mean_ep_len = round(float(np.mean(ep_lengths)), 1)
308
  _state.reward_history.append(round(ep_reward, 2))
309
+ prev_cum = _state.cumulative_reward[-1] if _state.cumulative_reward else 0.0
310
+ _state.cumulative_reward.append(round(prev_cum + ep_reward, 2))
311
  _state.accuracy_history.append(round(acc, 1))
 
 
 
312
  _state.episode_history.append(ep_rec)
 
 
313
 
314
  _push_sse({"type": "episode", "data": ep_rec})
315
 
 
417
  "pg_loss": s.last_pg_loss,
418
  "vf_loss": s.last_vf_loss,
419
  "entropy": s.last_entropy,
420
+ "reward_history": s.reward_history,
421
+ "cumulative_reward": s.cumulative_reward,
422
+ "accuracy_history": s.accuracy_history,
423
  "episode_history": s.episode_history[-50:],
424
  "incident_feed": s.incident_feed[-30:],
425
  "incident_counts": s.incident_counts,
 
592
  <!-- RIGHT -->
593
  <div class="right">
594
  <div class="chart-wrap" style="flex:0 0 45%">
595
+ <div class="chart-title" id="rw-title">EPISODE REWARD β€” CAPPED | mean: 0.00 | net: 0.00</div>
596
  <canvas class="chart" id="rwChart"></canvas>
597
  </div>
598
  <div class="chart-wrap" style="flex:0 0 28%">
 
618
  <script>
619
  // ── State ──────────────────────────────────────────────────────────────────
620
  let S = {
621
+ cars:[], reward_history:[], cumulative_reward:[], accuracy_history:[], episode_history:[],
622
  incident_feed:[], incident_counts:{}, stage:1, reward_mode:'capped',
623
  response_accuracy:0, total_steps:0, n_episodes:0, n_updates:0,
624
  ego_x:0, goal_x:180, episode_reward:0, episode_steps:0,
 
686
  }
687
 
688
  // ── Chart drawing ────────────────────────────────────────────────────────
689
+ // Shows ALL data from t=0 to current t. X-axis scales to fit all episodes.
690
+ // Draws: raw (faint), MA (bright), global mean (dashed yellow).
691
  function drawLineChart(canvasId, data, color, label, yMin, yMax, showZero) {
692
  const canvas = document.getElementById(canvasId);
693
  const w = canvas.offsetWidth||400, h = canvas.offsetHeight||160;
 
696
  ctx.clearRect(0,0,w,h);
697
  if(!data||data.length<2){
698
  ctx.fillStyle='#445'; ctx.font='11px Courier New'; ctx.textAlign='center';
699
+ ctx.fillText('Waiting for episodes...', w/2, h/2);
700
  return;
701
  }
702
+ const pad={t:10,r:10,b:24,l:52};
703
  const pw=w-pad.l-pad.r, ph=h-pad.t-pad.b;
704
  const mn = yMin!==undefined?yMin:Math.min(...data);
705
  const mx = yMax!==undefined?yMax:Math.max(...data);
706
  const rng = mx-mn||1;
707
+ const n = data.length;
708
+ const xOf = i => pad.l + i*(pw/(n-1||1));
709
+ const yOf = v => pad.t + (mx-v)/rng*ph;
710
+
711
+ // Grid lines + Y labels
712
  ctx.strokeStyle='#1a1a28'; ctx.lineWidth=1;
713
  for(let i=0;i<=4;i++){
714
  const y=pad.t+ph*(i/4);
 
716
  ctx.fillStyle='#445'; ctx.font='8px Courier New'; ctx.textAlign='right';
717
  ctx.fillText((mx-rng*(i/4)).toFixed(1), pad.l-3, y+3);
718
  }
719
+
720
  // Zero line
721
  if(showZero && mn<0 && mx>0){
722
+ const zy=yOf(0);
723
+ ctx.strokeStyle='#3a3a50'; ctx.lineWidth=1; ctx.setLineDash([4,4]);
724
  ctx.beginPath();ctx.moveTo(pad.l,zy);ctx.lineTo(pad.l+pw,zy);ctx.stroke();
725
  ctx.setLineDash([]);
726
  }
727
+
728
+ // MA window: adaptive
729
+ const MA = Math.max(5, Math.min(30, Math.floor(n/10)));
730
+ const ma = data.map((_,i)=>{
731
+ const sl=data.slice(Math.max(0,i-MA+1),i+1);
732
  return sl.reduce((a,b)=>a+b,0)/sl.length;
733
  });
734
+
735
+ // Global mean (horizontal dashed line)
736
+ const globalMean = data.reduce((a,b)=>a+b,0)/n;
737
+ const gy = yOf(globalMean);
738
+ ctx.strokeStyle='rgba(255,235,59,0.6)'; ctx.lineWidth=1; ctx.setLineDash([6,4]);
739
+ ctx.beginPath();ctx.moveTo(pad.l,gy);ctx.lineTo(pad.l+pw,gy);ctx.stroke();
740
+ ctx.setLineDash([]);
741
+ ctx.fillStyle='rgba(255,235,59,0.8)'; ctx.font='bold 9px Courier New'; ctx.textAlign='left';
742
+ ctx.fillText('\u03bc='+globalMean.toFixed(2), pad.l+4, gy-5);
743
+
744
+ // Raw line (faint)
745
+ ctx.strokeStyle=color+'33'; ctx.lineWidth=1; ctx.beginPath();
746
+ data.forEach((v,i)=>{ i?ctx.lineTo(xOf(i),yOf(v)):ctx.moveTo(xOf(i),yOf(v)); });
747
+ ctx.stroke();
748
+
749
+ // Smoothed MA line
750
  ctx.strokeStyle=color; ctx.lineWidth=2; ctx.beginPath();
751
+ ma.forEach((v,i)=>{ i?ctx.lineTo(xOf(i),yOf(v)):ctx.moveTo(xOf(i),yOf(v)); });
752
+ ctx.stroke();
753
+
754
+ // X-axis: t=0 on left, current episode on right
755
+ ctx.fillStyle='#445'; ctx.font='8px Courier New';
756
+ ctx.textAlign='left'; ctx.fillText('t=0', pad.l, h-4);
757
+ ctx.textAlign='right'; ctx.fillText('t='+n, pad.l+pw, h-4);
758
+ ctx.textAlign='center';ctx.fillText(label, pad.l+pw/2, h-4);
759
+ }
760
+
761
+ // Reward chart with dual Y-axes: per-episode reward (left) + cumulative net (right)
762
+ function drawRewardChart() {
763
+ const canvas = document.getElementById('rwChart');
764
+ const w = canvas.offsetWidth||400, h = canvas.offsetHeight||160;
765
+ canvas.width=w; canvas.height=h;
766
+ const ctx=canvas.getContext('2d');
767
+ ctx.clearRect(0,0,w,h);
768
+ const data = S.reward_history||[];
769
+ const cumul = S.cumulative_reward||[];
770
+ if(!data||data.length<2){
771
+ ctx.fillStyle='#445'; ctx.font='11px Courier New'; ctx.textAlign='center';
772
+ ctx.fillText('Waiting for episodes...', w/2, h/2);
773
+ return;
774
+ }
775
+ const pad={t:10,r:52,b:24,l:52};
776
+ const pw=w-pad.l-pad.r, ph=h-pad.t-pad.b;
777
+ const n = data.length;
778
+ const xOf = i => pad.l + i*(pw/(n-1||1));
779
+
780
+ // Left Y: per-episode reward
781
+ const mn1 = Math.min(...data);
782
+ const mx1 = Math.max(...data);
783
+ const rng1 = mx1-mn1||1;
784
+ const yOf1 = v => pad.t + (mx1-v)/rng1*ph;
785
+
786
+ // Right Y: cumulative net reward
787
+ const mn2 = cumul.length? Math.min(...cumul) : 0;
788
+ const mx2 = cumul.length? Math.max(...cumul) : 1;
789
+ const rng2 = mx2-mn2||1;
790
+ const yOf2 = v => pad.t + (mx2-v)/rng2*ph;
791
+
792
+ // Grid lines + left Y labels
793
+ ctx.strokeStyle='#1a1a28'; ctx.lineWidth=1;
794
+ for(let i=0;i<=4;i++){
795
+ const y=pad.t+ph*(i/4);
796
+ ctx.beginPath();ctx.moveTo(pad.l,y);ctx.lineTo(pad.l+pw,y);ctx.stroke();
797
+ ctx.fillStyle='#556'; ctx.font='8px Courier New'; ctx.textAlign='right';
798
+ ctx.fillText((mx1-rng1*(i/4)).toFixed(1), pad.l-3, y+3);
799
+ }
800
+ // Right Y labels (cumulative)
801
+ for(let i=0;i<=4;i++){
802
+ const y=pad.t+ph*(i/4);
803
+ ctx.fillStyle='#ff985580'; ctx.font='8px Courier New'; ctx.textAlign='left';
804
+ ctx.fillText((mx2-rng2*(i/4)).toFixed(0), pad.l+pw+3, y+3);
805
+ }
806
+
807
+ // Zero line
808
+ if(mn1<0 && mx1>0){
809
+ const zy=yOf1(0);
810
+ ctx.strokeStyle='#3a3a50'; ctx.lineWidth=1; ctx.setLineDash([4,4]);
811
+ ctx.beginPath();ctx.moveTo(pad.l,zy);ctx.lineTo(pad.l+pw,zy);ctx.stroke();
812
+ ctx.setLineDash([]);
813
+ }
814
+
815
+ // Global mean (horizontal dashed line)
816
+ const globalMean = data.reduce((a,b)=>a+b,0)/n;
817
+ const gy = yOf1(globalMean);
818
+ ctx.strokeStyle='rgba(255,235,59,0.6)'; ctx.lineWidth=1; ctx.setLineDash([6,4]);
819
+ ctx.beginPath();ctx.moveTo(pad.l,gy);ctx.lineTo(pad.l+pw,gy);ctx.stroke();
820
+ ctx.setLineDash([]);
821
+ ctx.fillStyle='rgba(255,235,59,0.9)'; ctx.font='bold 9px Courier New'; ctx.textAlign='left';
822
+ ctx.fillText('\u03bc='+globalMean.toFixed(2), pad.l+4, gy-5);
823
+
824
+ // MA
825
+ const MA = Math.max(5, Math.min(30, Math.floor(n/10)));
826
+ const ma = data.map((_,i)=>{
827
+ const sl=data.slice(Math.max(0,i-MA+1),i+1);
828
+ return sl.reduce((a,b)=>a+b,0)/sl.length;
829
+ });
830
+
831
+ // Raw per-episode line (faint blue)
832
+ ctx.strokeStyle='#7eb8ff33'; ctx.lineWidth=1; ctx.beginPath();
833
+ data.forEach((v,i)=>{ i?ctx.lineTo(xOf(i),yOf1(v)):ctx.moveTo(xOf(i),yOf1(v)); });
834
+ ctx.stroke();
835
+
836
+ // Smoothed MA per-episode (bright blue)
837
+ ctx.strokeStyle='#7eb8ff'; ctx.lineWidth=2; ctx.beginPath();
838
+ ma.forEach((v,i)=>{ i?ctx.lineTo(xOf(i),yOf1(v)):ctx.moveTo(xOf(i),yOf1(v)); });
839
+ ctx.stroke();
840
+
841
+ // Cumulative net reward (orange, right axis)
842
+ if(cumul.length>=2){
843
+ ctx.strokeStyle='#ff9855'; ctx.lineWidth=2; ctx.beginPath();
844
+ cumul.forEach((v,i)=>{ i?ctx.lineTo(xOf(i),yOf2(v)):ctx.moveTo(xOf(i),yOf2(v)); });
845
+ ctx.stroke();
846
+ }
847
+
848
+ // X-axis
849
+ ctx.fillStyle='#445'; ctx.font='8px Courier New';
850
+ ctx.textAlign='left'; ctx.fillText('t=0', pad.l, h-4);
851
+ ctx.textAlign='right'; ctx.fillText('t='+n, pad.l+pw, h-4);
852
+
853
+ // Legend
854
+ ctx.font='8px Courier New'; ctx.textAlign='center';
855
+ ctx.fillStyle='#7eb8ff'; ctx.fillText('\u25CF per-ep', pad.l+pw*0.3, h-4);
856
+ ctx.fillStyle='#ff9855'; ctx.fillText('\u25CF net cumul.', pad.l+pw*0.7, h-4);
857
+
858
+ // Update title
859
+ const net = cumul.length? cumul[cumul.length-1] : 0;
860
+ const modeLabel = S.reward_mode.toUpperCase();
861
+ document.getElementById('rw-title').textContent =
862
+ 'EPISODE REWARD \u2014 '+modeLabel+' | \u03bc: '+globalMean.toFixed(2)+' | net: '+net.toFixed(2);
863
  }
864
 
865
  // ── Incident feed ────────────────────────────────────────────────────────
 
955
  // ── Render all ────────────────────────────────────────────────────────────
956
  function renderAll() {
957
  drawRoad();
958
+ drawRewardChart();
959
  drawLineChart('accChart', S.accuracy_history, '#4caf50', 'Episodes', 0, 100, false);
960
  renderFeed(S.incident_feed||[]);
961
  renderEpTable(S.episode_history||[]);
 
983
  S.episode_history.push(msg.data);
984
  if(!S.reward_history) S.reward_history=[];
985
  S.reward_history.push(msg.data.reward);
986
+ if(!S.cumulative_reward) S.cumulative_reward=[];
987
+ const prevNet = S.cumulative_reward.length? S.cumulative_reward[S.cumulative_reward.length-1] : 0;
988
+ S.cumulative_reward.push(prevNet + msg.data.reward);
989
  if(!S.accuracy_history) S.accuracy_history=[];
990
  S.accuracy_history.push(msg.data.accuracy||0);
991
  renderEpTable(S.episode_history);
992
+ drawRewardChart();
993
  drawLineChart('accChart', S.accuracy_history, '#4caf50', 'Episodes', 0, 100, false);
994
  } else if(msg.type==='tick'){
995
  Object.assign(S, msg.data);
server/requirements.txt CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  openenv-core[core]>=0.2.1
2
  fastapi>=0.115.0
3
  pydantic>=2.0.0
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ torch>=2.5.0
3
+ gymnasium>=0.29.0
4
  openenv-core[core]>=0.2.1
5
  fastapi>=0.115.0
6
  pydantic>=2.0.0