Spaces:
Runtime error
Runtime error
Commit Β·
236d9fb
1
Parent(s): 83cf44b
feat: integrate trained RecurrentPPO policy into Streamlit demo UI
Browse files- demo/streamlit_app.py +139 -22
demo/streamlit_app.py
CHANGED
|
@@ -70,6 +70,65 @@ def _get_preset_tasks(n: int = 8) -> list[str]:
|
|
| 70 |
|
| 71 |
PRESET_TASKS = _get_preset_tasks()
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
DARK = dict(
|
| 74 |
paper_bgcolor="rgba(0,0,0,0)",
|
| 75 |
plot_bgcolor="rgba(0,0,0,0)",
|
|
@@ -101,6 +160,10 @@ class Session:
|
|
| 101 |
self.obs_history: list[dict] = []
|
| 102 |
# Specialists auto-spawned for this episode
|
| 103 |
self.spawned_specialists: list[str] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
def boot(self):
|
| 106 |
if self.env is None:
|
|
@@ -123,6 +186,9 @@ class Session:
|
|
| 123 |
self.step_entropies = []
|
| 124 |
self.obs_history = []
|
| 125 |
self.spawned_specialists: list[str] = list(info.get("spawned_specialists", []))
|
|
|
|
|
|
|
|
|
|
| 126 |
return obs, info
|
| 127 |
|
| 128 |
def step(self, action):
|
|
@@ -133,6 +199,8 @@ class Session:
|
|
| 133 |
self.actions.append(info)
|
| 134 |
self.step_n += 1
|
| 135 |
self.done = term or trunc
|
|
|
|
|
|
|
| 136 |
|
| 137 |
# Capture step snapshot for replay
|
| 138 |
called = info.get("called_specialists", [])
|
|
@@ -887,12 +955,27 @@ def tab_live_demo():
|
|
| 887 |
reset_btn = c1.button("Reset Episode", type="primary", use_container_width=True, key="reset_btn")
|
| 888 |
run_btn = c2.button("Run Full Episode", use_container_width=True, key="run_btn")
|
| 889 |
st.markdown('<div style="height:6px"></div>', unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
cat = _load_catalog()
|
| 891 |
-
act_type = st.selectbox("Action type",
|
| 892 |
["RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN"],
|
| 893 |
-
key="act_type"
|
|
|
|
| 894 |
spec_ids = [sp["id"] for sp in cat]
|
| 895 |
-
spec_ch = st.selectbox("Target specialist", spec_ids, key="spec_ch"
|
|
|
|
| 896 |
step_btn = st.button("Execute One Step",
|
| 897 |
disabled=(S.env is None or S.done),
|
| 898 |
use_container_width=True, key="step_btn")
|
|
@@ -918,25 +1001,31 @@ def tab_live_demo():
|
|
| 918 |
|
| 919 |
# ββ Step βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 920 |
if step_btn and S.env is not None and not S.done:
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
if spec_ch in ids:
|
| 927 |
-
idx = ids.index(spec_ch)
|
| 928 |
-
if idx < S.env.max_specialists:
|
| 929 |
-
action[1 + idx] = 1.0
|
| 930 |
-
else:
|
| 931 |
-
action[1] = 1.0
|
| 932 |
-
elif act_type == "PARALLEL SPAWN":
|
| 933 |
-
action[0] = 6.0
|
| 934 |
-
action[1] = 1.0
|
| 935 |
-
if S.env.max_specialists > 1:
|
| 936 |
-
action[2] = 1.0
|
| 937 |
-
action[1 + S.env.max_specialists] = 1.0
|
| 938 |
else:
|
| 939 |
-
action = S.env.action_space.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 940 |
|
| 941 |
_, r, term, trunc, info = S.step(action)
|
| 942 |
done = term or trunc
|
|
@@ -962,7 +1051,14 @@ def tab_live_demo():
|
|
| 962 |
for _ in range(15):
|
| 963 |
if S.done:
|
| 964 |
break
|
| 965 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 966 |
# Use cumulative called_ids so graph stays populated even after STOP step
|
| 967 |
called = list(S.env.called_ids) if S.env else []
|
| 968 |
edges = [(e.caller_id, e.callee_id)
|
|
@@ -1194,6 +1290,27 @@ def tab_specialists():
|
|
| 1194 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1195 |
def tab_training():
|
| 1196 |
sec("Training Progress β Mean Reward per Episode")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1197 |
st.plotly_chart(fig_training_curve(), use_container_width=True)
|
| 1198 |
|
| 1199 |
sec("Policy Entropy β Action Confidence Over Training")
|
|
|
|
| 70 |
|
| 71 |
PRESET_TASKS = _get_preset_tasks()
|
| 72 |
|
| 73 |
+
HF_MODEL_REPO = "garvitsachdeva/spindleflow-rl"
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@st.cache_resource
|
| 77 |
+
def _load_trained_model(hf_repo: str):
|
| 78 |
+
"""Download RecurrentPPO + VecNormalize stats from HF Hub.
|
| 79 |
+
|
| 80 |
+
Returns (model, obs_mean, obs_var, clip_obs, error_str).
|
| 81 |
+
Temporarily lifts the HF_HUB_OFFLINE flag set at module level.
|
| 82 |
+
"""
|
| 83 |
+
import pickle
|
| 84 |
+
_old_hf = os.environ.pop("HF_HUB_OFFLINE", None)
|
| 85 |
+
_old_tf = os.environ.pop("TRANSFORMERS_OFFLINE", None)
|
| 86 |
+
try:
|
| 87 |
+
from huggingface_hub import hf_hub_download
|
| 88 |
+
from sb3_contrib import RecurrentPPO
|
| 89 |
+
|
| 90 |
+
model = RecurrentPPO.load(
|
| 91 |
+
hf_hub_download(hf_repo, "spindleflow_model.zip"), device="cpu"
|
| 92 |
+
)
|
| 93 |
+
obs_mean = obs_var = None
|
| 94 |
+
clip_obs = 10.0
|
| 95 |
+
try:
|
| 96 |
+
stats_path = hf_hub_download(hf_repo, "vec_normalize.pkl")
|
| 97 |
+
with open(stats_path, "rb") as f:
|
| 98 |
+
vn = pickle.load(f)
|
| 99 |
+
obs_mean = vn.obs_rms.mean.copy()
|
| 100 |
+
obs_var = vn.obs_rms.var.copy()
|
| 101 |
+
clip_obs = float(vn.clip_obs)
|
| 102 |
+
except Exception:
|
| 103 |
+
pass
|
| 104 |
+
return model, obs_mean, obs_var, clip_obs, None
|
| 105 |
+
except Exception as exc:
|
| 106 |
+
return None, None, None, 10.0, str(exc)
|
| 107 |
+
finally:
|
| 108 |
+
if _old_hf is not None:
|
| 109 |
+
os.environ["HF_HUB_OFFLINE"] = _old_hf
|
| 110 |
+
if _old_tf is not None:
|
| 111 |
+
os.environ["TRANSFORMERS_OFFLINE"] = _old_tf
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _predict(model, obs: np.ndarray, lstm_states, episode_starts,
|
| 115 |
+
obs_mean, obs_var, clip_obs: float):
|
| 116 |
+
"""Normalize obs and call model.predict(); return (action, new_lstm_states)."""
|
| 117 |
+
obs_arr = obs[np.newaxis, :].copy().astype(np.float32)
|
| 118 |
+
if obs_mean is not None and obs_var is not None:
|
| 119 |
+
obs_arr = np.clip(
|
| 120 |
+
(obs_arr - obs_mean) / np.sqrt(obs_var + 1e-8),
|
| 121 |
+
-clip_obs, clip_obs,
|
| 122 |
+
)
|
| 123 |
+
action_batch, new_states = model.predict(
|
| 124 |
+
obs_arr,
|
| 125 |
+
state=lstm_states,
|
| 126 |
+
episode_start=episode_starts,
|
| 127 |
+
deterministic=True,
|
| 128 |
+
)
|
| 129 |
+
return action_batch[0], new_states
|
| 130 |
+
|
| 131 |
+
|
| 132 |
DARK = dict(
|
| 133 |
paper_bgcolor="rgba(0,0,0,0)",
|
| 134 |
plot_bgcolor="rgba(0,0,0,0)",
|
|
|
|
| 160 |
self.obs_history: list[dict] = []
|
| 161 |
# Specialists auto-spawned for this episode
|
| 162 |
self.spawned_specialists: list[str] = []
|
| 163 |
+
# Trained policy inference state
|
| 164 |
+
self.obs_current: np.ndarray | None = None
|
| 165 |
+
self.lstm_states = None
|
| 166 |
+
self.episode_starts = np.array([True])
|
| 167 |
|
| 168 |
def boot(self):
|
| 169 |
if self.env is None:
|
|
|
|
| 186 |
self.step_entropies = []
|
| 187 |
self.obs_history = []
|
| 188 |
self.spawned_specialists: list[str] = list(info.get("spawned_specialists", []))
|
| 189 |
+
self.obs_current = obs
|
| 190 |
+
self.lstm_states = None
|
| 191 |
+
self.episode_starts = np.array([True])
|
| 192 |
return obs, info
|
| 193 |
|
| 194 |
def step(self, action):
|
|
|
|
| 199 |
self.actions.append(info)
|
| 200 |
self.step_n += 1
|
| 201 |
self.done = term or trunc
|
| 202 |
+
self.obs_current = obs
|
| 203 |
+
self.episode_starts = np.array([self.done])
|
| 204 |
|
| 205 |
# Capture step snapshot for replay
|
| 206 |
called = info.get("called_specialists", [])
|
|
|
|
| 955 |
reset_btn = c1.button("Reset Episode", type="primary", use_container_width=True, key="reset_btn")
|
| 956 |
run_btn = c2.button("Run Full Episode", use_container_width=True, key="run_btn")
|
| 957 |
st.markdown('<div style="height:6px"></div>', unsafe_allow_html=True)
|
| 958 |
+
|
| 959 |
+
use_trained = st.checkbox("π€ Use Trained Policy", value=False, key="use_trained",
|
| 960 |
+
help="Load the trained RecurrentPPO model from HF Hub")
|
| 961 |
+
trained_model = obs_mean = obs_var = None
|
| 962 |
+
clip_obs = 10.0
|
| 963 |
+
if use_trained:
|
| 964 |
+
with st.spinner("Loading trained model from HF Hubβ¦"):
|
| 965 |
+
trained_model, obs_mean, obs_var, clip_obs, model_err = _load_trained_model(HF_MODEL_REPO)
|
| 966 |
+
if model_err:
|
| 967 |
+
st.error(f"Model load failed: {model_err}")
|
| 968 |
+
else:
|
| 969 |
+
st.success("Trained policy loaded β")
|
| 970 |
+
|
| 971 |
cat = _load_catalog()
|
| 972 |
+
act_type = st.selectbox("Action type (manual mode)",
|
| 973 |
["RANDOM", "STOP", "CALL SPECIALIST", "PARALLEL SPAWN"],
|
| 974 |
+
key="act_type",
|
| 975 |
+
disabled=use_trained)
|
| 976 |
spec_ids = [sp["id"] for sp in cat]
|
| 977 |
+
spec_ch = st.selectbox("Target specialist", spec_ids, key="spec_ch",
|
| 978 |
+
disabled=use_trained)
|
| 979 |
step_btn = st.button("Execute One Step",
|
| 980 |
disabled=(S.env is None or S.done),
|
| 981 |
use_container_width=True, key="step_btn")
|
|
|
|
| 1001 |
|
| 1002 |
# ββ Step βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1003 |
if step_btn and S.env is not None and not S.done:
|
| 1004 |
+
if use_trained and trained_model is not None and S.obs_current is not None:
|
| 1005 |
+
action, S.lstm_states = _predict(
|
| 1006 |
+
trained_model, S.obs_current, S.lstm_states,
|
| 1007 |
+
S.episode_starts, obs_mean, obs_var, clip_obs,
|
| 1008 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1009 |
else:
|
| 1010 |
+
action = np.zeros(S.env.action_space.shape, dtype=np.float32)
|
| 1011 |
+
if act_type == "STOP":
|
| 1012 |
+
action[0] = 1.0
|
| 1013 |
+
elif act_type == "CALL SPECIALIST":
|
| 1014 |
+
ids = S.registry.list_ids()
|
| 1015 |
+
if spec_ch in ids:
|
| 1016 |
+
idx = ids.index(spec_ch)
|
| 1017 |
+
if idx < S.env.max_specialists:
|
| 1018 |
+
action[1 + idx] = 1.0
|
| 1019 |
+
else:
|
| 1020 |
+
action[1] = 1.0
|
| 1021 |
+
elif act_type == "PARALLEL SPAWN":
|
| 1022 |
+
action[0] = 6.0
|
| 1023 |
+
action[1] = 1.0
|
| 1024 |
+
if S.env.max_specialists > 1:
|
| 1025 |
+
action[2] = 1.0
|
| 1026 |
+
action[1 + S.env.max_specialists] = 1.0
|
| 1027 |
+
else:
|
| 1028 |
+
action = S.env.action_space.sample()
|
| 1029 |
|
| 1030 |
_, r, term, trunc, info = S.step(action)
|
| 1031 |
done = term or trunc
|
|
|
|
| 1051 |
for _ in range(15):
|
| 1052 |
if S.done:
|
| 1053 |
break
|
| 1054 |
+
if use_trained and trained_model is not None and S.obs_current is not None:
|
| 1055 |
+
action, S.lstm_states = _predict(
|
| 1056 |
+
trained_model, S.obs_current, S.lstm_states,
|
| 1057 |
+
S.episode_starts, obs_mean, obs_var, clip_obs,
|
| 1058 |
+
)
|
| 1059 |
+
else:
|
| 1060 |
+
action = S.env.action_space.sample()
|
| 1061 |
+
_, _, _, _, info = S.step(action)
|
| 1062 |
# Use cumulative called_ids so graph stays populated even after STOP step
|
| 1063 |
called = list(S.env.called_ids) if S.env else []
|
| 1064 |
edges = [(e.caller_id, e.callee_id)
|
|
|
|
| 1290 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1291 |
def tab_training():
|
| 1292 |
sec("Training Progress β Mean Reward per Episode")
|
| 1293 |
+
|
| 1294 |
+
c_fetch, _ = st.columns([2, 5])
|
| 1295 |
+
if c_fetch.button("π₯ Fetch latest curve from HF Hub", key="fetch_curve"):
|
| 1296 |
+
_old_hf = os.environ.pop("HF_HUB_OFFLINE", None)
|
| 1297 |
+
_old_tf = os.environ.pop("TRANSFORMERS_OFFLINE", None)
|
| 1298 |
+
try:
|
| 1299 |
+
import shutil
|
| 1300 |
+
from huggingface_hub import hf_hub_download
|
| 1301 |
+
src = hf_hub_download(HF_MODEL_REPO, "reward_curve.json")
|
| 1302 |
+
ASSETS.mkdir(parents=True, exist_ok=True)
|
| 1303 |
+
shutil.copy(src, ASSETS / "reward_curve.json")
|
| 1304 |
+
st.success("reward_curve.json updated β chart will refresh.")
|
| 1305 |
+
st.cache_data.clear()
|
| 1306 |
+
except Exception as exc:
|
| 1307 |
+
st.error(f"Download failed: {exc}")
|
| 1308 |
+
finally:
|
| 1309 |
+
if _old_hf is not None:
|
| 1310 |
+
os.environ["HF_HUB_OFFLINE"] = _old_hf
|
| 1311 |
+
if _old_tf is not None:
|
| 1312 |
+
os.environ["TRANSFORMERS_OFFLINE"] = _old_tf
|
| 1313 |
+
|
| 1314 |
st.plotly_chart(fig_training_curve(), use_container_width=True)
|
| 1315 |
|
| 1316 |
sec("Policy Entropy β Action Confidence Over Training")
|