garvitsachdeva commited on
Commit
236d9fb
Β·
1 Parent(s): 83cf44b

feat: integrate trained RecurrentPPO policy into Streamlit demo UI

Browse files
Files changed (1) hide show
  1. demo/streamlit_app.py +139 -22
demo/streamlit_app.py CHANGED
@@ -70,6 +70,65 @@ def _get_preset_tasks(n: int = 8) -> list[str]:
70
 
71
  PRESET_TASKS = _get_preset_tasks()
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  DARK = dict(
74
  paper_bgcolor="rgba(0,0,0,0)",
75
  plot_bgcolor="rgba(0,0,0,0)",
@@ -101,6 +160,10 @@ class Session:
101
  self.obs_history: list[dict] = []
102
  # Specialists auto-spawned for this episode
103
  self.spawned_specialists: list[str] = []
 
 
 
 
104
 
105
  def boot(self):
106
  if self.env is None:
@@ -123,6 +186,9 @@ class Session:
123
  self.step_entropies = []
124
  self.obs_history = []
125
  self.spawned_specialists: list[str] = list(info.get("spawned_specialists", []))
 
 
 
126
  return obs, info
127
 
128
  def step(self, action):
@@ -133,6 +199,8 @@ class Session:
133
  self.actions.append(info)
134
  self.step_n += 1
135
  self.done = term or trunc
 
 
136
 
137
  # Capture step snapshot for replay
138
  called = info.get("called_specialists", [])
@@ -887,12 +955,27 @@ def tab_live_demo():
887
  reset_btn = c1.button("Reset Episode", type="primary", use_container_width=True, key="reset_btn")
888
  run_btn = c2.button("Run Full Episode", use_container_width=True, key="run_btn")
889
  st.markdown('<div style="height:6px"></div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
890
  cat = _load_catalog()
891
- act_type = st.selectbox("Action type",
892
  ["RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN"],
893
- key="act_type")
 
894
  spec_ids = [sp["id"] for sp in cat]
895
- spec_ch = st.selectbox("Target specialist", spec_ids, key="spec_ch")
 
896
  step_btn = st.button("Execute One Step",
897
  disabled=(S.env is None or S.done),
898
  use_container_width=True, key="step_btn")
@@ -918,25 +1001,31 @@ def tab_live_demo():
918
 
919
  # ── Step ───────────────────────────────────────────────
920
  if step_btn and S.env is not None and not S.done:
921
- action = np.zeros(S.env.action_space.shape, dtype=np.float32)
922
- if act_type == "STOP":
923
- action[0] = 1.0
924
- elif act_type == "CALL SPECIALIST":
925
- ids = S.registry.list_ids()
926
- if spec_ch in ids:
927
- idx = ids.index(spec_ch)
928
- if idx < S.env.max_specialists:
929
- action[1 + idx] = 1.0
930
- else:
931
- action[1] = 1.0
932
- elif act_type == "PARALLEL SPAWN":
933
- action[0] = 6.0
934
- action[1] = 1.0
935
- if S.env.max_specialists > 1:
936
- action[2] = 1.0
937
- action[1 + S.env.max_specialists] = 1.0
938
  else:
939
- action = S.env.action_space.sample()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
940
 
941
  _, r, term, trunc, info = S.step(action)
942
  done = term or trunc
@@ -962,7 +1051,14 @@ def tab_live_demo():
962
  for _ in range(15):
963
  if S.done:
964
  break
965
- _, _, _, _, info = S.step(S.env.action_space.sample())
 
 
 
 
 
 
 
966
  # Use cumulative called_ids so graph stays populated even after STOP step
967
  called = list(S.env.called_ids) if S.env else []
968
  edges = [(e.caller_id, e.callee_id)
@@ -1194,6 +1290,27 @@ def tab_specialists():
1194
  # ─────────────────────────────────────────────────────────
1195
  def tab_training():
1196
  sec("Training Progress β€” Mean Reward per Episode")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1197
  st.plotly_chart(fig_training_curve(), use_container_width=True)
1198
 
1199
  sec("Policy Entropy β€” Action Confidence Over Training")
 
70
 
71
  PRESET_TASKS = _get_preset_tasks()
72
 
73
+ HF_MODEL_REPO = "garvitsachdeva/spindleflow-rl"
74
+
75
+
76
+ @st.cache_resource
77
+ def _load_trained_model(hf_repo: str):
78
+ """Download RecurrentPPO + VecNormalize stats from HF Hub.
79
+
80
+ Returns (model, obs_mean, obs_var, clip_obs, error_str).
81
+ Temporarily lifts the HF_HUB_OFFLINE flag set at module level.
82
+ """
83
+ import pickle
84
+ _old_hf = os.environ.pop("HF_HUB_OFFLINE", None)
85
+ _old_tf = os.environ.pop("TRANSFORMERS_OFFLINE", None)
86
+ try:
87
+ from huggingface_hub import hf_hub_download
88
+ from sb3_contrib import RecurrentPPO
89
+
90
+ model = RecurrentPPO.load(
91
+ hf_hub_download(hf_repo, "spindleflow_model.zip"), device="cpu"
92
+ )
93
+ obs_mean = obs_var = None
94
+ clip_obs = 10.0
95
+ try:
96
+ stats_path = hf_hub_download(hf_repo, "vec_normalize.pkl")
97
+ with open(stats_path, "rb") as f:
98
+ vn = pickle.load(f)
99
+ obs_mean = vn.obs_rms.mean.copy()
100
+ obs_var = vn.obs_rms.var.copy()
101
+ clip_obs = float(vn.clip_obs)
102
+ except Exception:
103
+ pass
104
+ return model, obs_mean, obs_var, clip_obs, None
105
+ except Exception as exc:
106
+ return None, None, None, 10.0, str(exc)
107
+ finally:
108
+ if _old_hf is not None:
109
+ os.environ["HF_HUB_OFFLINE"] = _old_hf
110
+ if _old_tf is not None:
111
+ os.environ["TRANSFORMERS_OFFLINE"] = _old_tf
112
+
113
+
114
+ def _predict(model, obs: np.ndarray, lstm_states, episode_starts,
115
+ obs_mean, obs_var, clip_obs: float):
116
+ """Normalize obs and call model.predict(); return (action, new_lstm_states)."""
117
+ obs_arr = obs[np.newaxis, :].copy().astype(np.float32)
118
+ if obs_mean is not None and obs_var is not None:
119
+ obs_arr = np.clip(
120
+ (obs_arr - obs_mean) / np.sqrt(obs_var + 1e-8),
121
+ -clip_obs, clip_obs,
122
+ )
123
+ action_batch, new_states = model.predict(
124
+ obs_arr,
125
+ state=lstm_states,
126
+ episode_start=episode_starts,
127
+ deterministic=True,
128
+ )
129
+ return action_batch[0], new_states
130
+
131
+
132
  DARK = dict(
133
  paper_bgcolor="rgba(0,0,0,0)",
134
  plot_bgcolor="rgba(0,0,0,0)",
 
160
  self.obs_history: list[dict] = []
161
  # Specialists auto-spawned for this episode
162
  self.spawned_specialists: list[str] = []
163
+ # Trained policy inference state
164
+ self.obs_current: np.ndarray | None = None
165
+ self.lstm_states = None
166
+ self.episode_starts = np.array([True])
167
 
168
  def boot(self):
169
  if self.env is None:
 
186
  self.step_entropies = []
187
  self.obs_history = []
188
  self.spawned_specialists: list[str] = list(info.get("spawned_specialists", []))
189
+ self.obs_current = obs
190
+ self.lstm_states = None
191
+ self.episode_starts = np.array([True])
192
  return obs, info
193
 
194
  def step(self, action):
 
199
  self.actions.append(info)
200
  self.step_n += 1
201
  self.done = term or trunc
202
+ self.obs_current = obs
203
+ self.episode_starts = np.array([self.done])
204
 
205
  # Capture step snapshot for replay
206
  called = info.get("called_specialists", [])
 
955
  reset_btn = c1.button("Reset Episode", type="primary", use_container_width=True, key="reset_btn")
956
  run_btn = c2.button("Run Full Episode", use_container_width=True, key="run_btn")
957
  st.markdown('<div style="height:6px"></div>', unsafe_allow_html=True)
958
+
959
+ use_trained = st.checkbox("πŸ€– Use Trained Policy", value=False, key="use_trained",
960
+ help="Load the trained RecurrentPPO model from HF Hub")
961
+ trained_model = obs_mean = obs_var = None
962
+ clip_obs = 10.0
963
+ if use_trained:
964
+ with st.spinner("Loading trained model from HF Hub…"):
965
+ trained_model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO)
966
+ if model_err:
967
+ st.error(f"Model load failed: {model_err}")
968
+ else:
969
+ st.success("Trained policy loaded βœ“")
970
+
971
  cat = _load_catalog()
972
+ act_type = st.selectbox("Action type (manual mode)",
973
  ["RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN"],
974
+ key="act_type",
975
+ disabled=use_trained)
976
  spec_ids = [sp["id"] for sp in cat]
977
+ spec_ch = st.selectbox("Target specialist", spec_ids, key="spec_ch",
978
+ disabled=use_trained)
979
  step_btn = st.button("Execute One Step",
980
  disabled=(S.env is None or S.done),
981
  use_container_width=True, key="step_btn")
 
1001
 
1002
  # ── Step ───────────────────────────────────────────────
1003
  if step_btn and S.env is not None and not S.done:
1004
+ if use_trained and trained_model is not None and S.obs_current is not None:
1005
+ action, S.lstm_states = _predict(
1006
+ trained_model, S.obs_current, S.lstm_states,
1007
+ S.episode_starts, obs_mean, obs_var, clip_obs,
1008
+ )
 
 
 
 
 
 
 
 
 
 
 
 
1009
  else:
1010
+ action = np.zeros(S.env.action_space.shape, dtype=np.float32)
1011
+ if act_type == "STOP":
1012
+ action[0] = 1.0
1013
+ elif act_type == "CALL SPECIALIST":
1014
+ ids = S.registry.list_ids()
1015
+ if spec_ch in ids:
1016
+ idx = ids.index(spec_ch)
1017
+ if idx < S.env.max_specialists:
1018
+ action[1 + idx] = 1.0
1019
+ else:
1020
+ action[1] = 1.0
1021
+ elif act_type == "PARALLEL SPAWN":
1022
+ action[0] = 6.0
1023
+ action[1] = 1.0
1024
+ if S.env.max_specialists > 1:
1025
+ action[2] = 1.0
1026
+ action[1 + S.env.max_specialists] = 1.0
1027
+ else:
1028
+ action = S.env.action_space.sample()
1029
 
1030
  _, r, term, trunc, info = S.step(action)
1031
  done = term or trunc
 
1051
  for _ in range(15):
1052
  if S.done:
1053
  break
1054
+ if use_trained and trained_model is not None and S.obs_current is not None:
1055
+ action, S.lstm_states = _predict(
1056
+ trained_model, S.obs_current, S.lstm_states,
1057
+ S.episode_starts, obs_mean, obs_var, clip_obs,
1058
+ )
1059
+ else:
1060
+ action = S.env.action_space.sample()
1061
+ _, _, _, _, info = S.step(action)
1062
  # Use cumulative called_ids so graph stays populated even after STOP step
1063
  called = list(S.env.called_ids) if S.env else []
1064
  edges = [(e.caller_id, e.callee_id)
 
1290
  # ─────────────────────────────────────────────────────────
1291
  def tab_training():
1292
  sec("Training Progress β€” Mean Reward per Episode")
1293
+
1294
+ c_fetch, _ = st.columns([2, 5])
1295
+ if c_fetch.button("πŸ“₯ Fetch latest curve from HF Hub", key="fetch_curve"):
1296
+ _old_hf = os.environ.pop("HF_HUB_OFFLINE", None)
1297
+ _old_tf = os.environ.pop("TRANSFORMERS_OFFLINE", None)
1298
+ try:
1299
+ import shutil
1300
+ from huggingface_hub import hf_hub_download
1301
+ src = hf_hub_download(HF_MODEL_REPO, "reward_curve.json")
1302
+ ASSETS.mkdir(parents=True, exist_ok=True)
1303
+ shutil.copy(src, ASSETS / "reward_curve.json")
1304
+ st.success("reward_curve.json updated β€” chart will refresh.")
1305
+ st.cache_data.clear()
1306
+ except Exception as exc:
1307
+ st.error(f"Download failed: {exc}")
1308
+ finally:
1309
+ if _old_hf is not None:
1310
+ os.environ["HF_HUB_OFFLINE"] = _old_hf
1311
+ if _old_tf is not None:
1312
+ os.environ["TRANSFORMERS_OFFLINE"] = _old_tf
1313
+
1314
  st.plotly_chart(fig_training_curve(), use_container_width=True)
1315
 
1316
  sec("Policy Entropy β€” Action Confidence Over Training")