Spaces:
Runtime error
Runtime error
Commit ·
8e0fa29
1
Parent(s): 278b691
Add HF Hub push cell + fix pip deps in Colab training script
Browse filesCo-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- colab/train_colab.py +98 -0
colab/train_colab.py
CHANGED
|
@@ -29,6 +29,7 @@ subprocess.run([
|
|
| 29 |
"openenv", "stable-baselines3", "sb3-contrib", "gymnasium",
|
| 30 |
"sentence-transformers", "openai", "pyyaml", "trl",
|
| 31 |
"transformers", "datasets", "torch",
|
|
|
|
| 32 |
], check=True)
|
| 33 |
print("Packages OK")
|
| 34 |
|
|
@@ -395,3 +396,100 @@ else:
|
|
| 395 |
print("\n" + "="*55)
|
| 396 |
print("All learning features verified. Ready for final checkpoint.")
|
| 397 |
print("="*55)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
"openenv", "stable-baselines3", "sb3-contrib", "gymnasium",
|
| 30 |
"sentence-transformers", "openai", "pyyaml", "trl",
|
| 31 |
"transformers", "datasets", "torch",
|
| 32 |
+
"matplotlib", "audioop-lts", "huggingface_hub",
|
| 33 |
], check=True)
|
| 34 |
print("Packages OK")
|
| 35 |
|
|
|
|
| 396 |
print("\n" + "="*55)
|
| 397 |
print("All learning features verified. Ready for final checkpoint.")
|
| 398 |
print("="*55)
|
| 399 |
+
|
| 400 |
+
# ============================================================
|
| 401 |
+
# CELL 8 — Push trained model + artifacts to HuggingFace Hub
|
| 402 |
+
#
|
| 403 |
+
# Requires HF_TOKEN secret set in Colab:
|
| 404 |
+
# Runtime > Manage secrets (key icon in left sidebar)
|
| 405 |
+
# Name: HF_TOKEN Value: hf_xxxxx (write token from hf.co/settings/tokens)
|
| 406 |
+
#
|
| 407 |
+
# Target repo: garvitsachdeva/spindleflow-rl
|
| 408 |
+
# ============================================================
|
| 409 |
+
import numpy as np
|
| 410 |
+
from huggingface_hub import HfApi, CommitOperationAdd
|
| 411 |
+
from google.colab import userdata
|
| 412 |
+
|
| 413 |
+
HF_TOKEN = userdata.get("HF_TOKEN")
|
| 414 |
+
if not HF_TOKEN:
|
| 415 |
+
raise RuntimeError("HF_TOKEN not set. Go to Runtime > Manage secrets and add it.")
|
| 416 |
+
|
| 417 |
+
HF_REPO = "garvitsachdeva/spindleflow-rl"
|
| 418 |
+
api = HfApi(token=HF_TOKEN)
|
| 419 |
+
_repo_name = HF_REPO.split("/")[-1]
|
| 420 |
+
|
| 421 |
+
print(f"Pushing to https://huggingface.co/{HF_REPO} ...")
|
| 422 |
+
api.create_repo(repo_id=_repo_name, repo_type="model", exist_ok=True)
|
| 423 |
+
|
| 424 |
+
ep = reward_logger.episode_rewards
|
| 425 |
+
f5 = float(np.mean(ep[:5])) if len(ep) >= 5 else 0.0
|
| 426 |
+
l5 = float(np.mean(ep[-5:])) if len(ep) >= 5 else 0.0
|
| 427 |
+
total_steps_run = int(_cfg.get("training", {}).get("total_timesteps", 500_000))
|
| 428 |
+
|
| 429 |
+
readme_text = f"""---
|
| 430 |
+
license: mit
|
| 431 |
+
tags:
|
| 432 |
+
- reinforcement-learning
|
| 433 |
+
- stable-baselines3
|
| 434 |
+
- sb3-contrib
|
| 435 |
+
- gymnasium
|
| 436 |
+
- multi-agent
|
| 437 |
+
- openenv
|
| 438 |
+
library_name: stable-baselines3
|
| 439 |
+
---
|
| 440 |
+
|
| 441 |
+
# SpindleFlow RL — Delegation Policy
|
| 442 |
+
|
| 443 |
+
LSTM PPO agent trained on SpindleFlow-v0 (OpenEnv).
|
| 444 |
+
|
| 445 |
+
## Training summary
|
| 446 |
+
| Metric | Value |
|
| 447 |
+
|---|---|
|
| 448 |
+
| Algorithm | RecurrentPPO (SB3 + sb3-contrib) |
|
| 449 |
+
| Total timesteps | {total_steps_run:,} |
|
| 450 |
+
| Episodes completed | {len(ep)} |
|
| 451 |
+
| First-5 mean reward | {f5:.4f} |
|
| 452 |
+
| Last-5 mean reward | {l5:.4f} |
|
| 453 |
+
| Improvement | {l5 - f5:+.4f} |
|
| 454 |
+
|
| 455 |
+

|
| 456 |
+
|
| 457 |
+
## Load
|
| 458 |
+
```python
|
| 459 |
+
from sb3_contrib import RecurrentPPO
|
| 460 |
+
from huggingface_hub import hf_hub_download
|
| 461 |
+
model = RecurrentPPO.load(hf_hub_download("{HF_REPO}", "spindleflow_model.zip"))
|
| 462 |
+
```
|
| 463 |
+
"""
|
| 464 |
+
|
| 465 |
+
readme_path = "/content/README_model.md"
|
| 466 |
+
with open(readme_path, "w") as f:
|
| 467 |
+
f.write(readme_text)
|
| 468 |
+
|
| 469 |
+
candidates = [
|
| 470 |
+
("/content/spindleflow_colab_demo.zip", "spindleflow_model.zip"),
|
| 471 |
+
("/content/vec_normalize_colab.pkl", "vec_normalize.pkl"),
|
| 472 |
+
("/content/reward_curve.png", "reward_curve.png"),
|
| 473 |
+
("/content/demo/assets/reward_curve.json", "reward_curve.json"),
|
| 474 |
+
(readme_path, "README.md"),
|
| 475 |
+
]
|
| 476 |
+
|
| 477 |
+
ops = [
|
| 478 |
+
CommitOperationAdd(path_in_repo=dst, path_or_fileobj=src)
|
| 479 |
+
for src, dst in candidates
|
| 480 |
+
if os.path.exists(src)
|
| 481 |
+
]
|
| 482 |
+
|
| 483 |
+
api.create_commit(
|
| 484 |
+
repo_id=HF_REPO,
|
| 485 |
+
repo_type="model",
|
| 486 |
+
operations=ops,
|
| 487 |
+
commit_message="Add trained SpindleFlow RL policy (Colab T4)",
|
| 488 |
+
token=HF_TOKEN,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
print(f"Uploaded {len(ops)} files.")
|
| 492 |
+
print(f"Model live at: https://huggingface.co/{HF_REPO}")
|
| 493 |
+
print(f"First-5 mean reward : {f5:.4f}")
|
| 494 |
+
print(f"Last-5 mean reward : {l5:.4f}")
|
| 495 |
+
print(f"Improvement : {l5 - f5:+.4f}")
|