garvitsachdeva Claude Sonnet 4.6 committed on
Commit
8e0fa29
·
1 Parent(s): 278b691

Add HF Hub push cell + fix pip deps in Colab training script

Browse files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. colab/train_colab.py +98 -0
colab/train_colab.py CHANGED
@@ -29,6 +29,7 @@ subprocess.run([
29
  "openenv", "stable-baselines3", "sb3-contrib", "gymnasium",
30
  "sentence-transformers", "openai", "pyyaml", "trl",
31
  "transformers", "datasets", "torch",
 
32
  ], check=True)
33
  print("Packages OK")
34
 
@@ -395,3 +396,100 @@ else:
395
  print("\n" + "="*55)
396
  print("All learning features verified. Ready for final checkpoint.")
397
  print("="*55)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  "openenv", "stable-baselines3", "sb3-contrib", "gymnasium",
30
  "sentence-transformers", "openai", "pyyaml", "trl",
31
  "transformers", "datasets", "torch",
32
+ "matplotlib", "audioop-lts", "huggingface_hub",
33
  ], check=True)
34
  print("Packages OK")
35
 
 
396
  print("\n" + "="*55)
397
  print("All learning features verified. Ready for final checkpoint.")
398
  print("="*55)
399
+
# ============================================================
# CELL 8 — Push trained model + artifacts to HuggingFace Hub
#
# Requires an HF_TOKEN secret set in Colab:
#   Runtime > Manage secrets (key icon in left sidebar)
#   Name: HF_TOKEN   Value: hf_xxxxx (write token from hf.co/settings/tokens)
#
# Target repo: garvitsachdeva/spindleflow-rl
#
# Relies on names created by earlier cells: `reward_logger` (must expose
# `episode_rewards`) and `_cfg` (the loaded training config mapping).
# ============================================================
import os  # local import so the cell does not depend on an earlier cell importing it

import numpy as np
from huggingface_hub import HfApi, CommitOperationAdd
from google.colab import userdata

# NOTE(review): userdata.get may itself raise when the secret is missing or
# notebook access has not been granted — confirm; the check below only
# covers an empty/None value.
HF_TOKEN = userdata.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN not set. Go to Runtime > Manage secrets and add it.")

HF_REPO = "garvitsachdeva/spindleflow-rl"
api = HfApi(token=HF_TOKEN)

print(f"Pushing to https://huggingface.co/{HF_REPO} ...")
# Use the fully qualified repo id so the repo that gets created is the same
# one create_commit() pushes to below. A bare repo name would be created
# under whichever account owns the token, which need not be this namespace.
api.create_repo(repo_id=HF_REPO, repo_type="model", exist_ok=True)

# Summary statistics for the model card. With fewer than five episodes the
# means are reported as 0.0 rather than computed over a partial window.
ep = reward_logger.episode_rewards
f5 = float(np.mean(ep[:5])) if len(ep) >= 5 else 0.0
l5 = float(np.mean(ep[-5:])) if len(ep) >= 5 else 0.0
# This is the configured timestep budget read from the config (default
# 500,000), not a count measured from the training run itself.
total_steps_run = int(_cfg.get("training", {}).get("total_timesteps", 500_000))

readme_text = f"""---
license: mit
tags:
- reinforcement-learning
- stable-baselines3
- sb3-contrib
- gymnasium
- multi-agent
- openenv
library_name: stable-baselines3
---

# SpindleFlow RL — Delegation Policy

LSTM PPO agent trained on SpindleFlow-v0 (OpenEnv).

## Training summary
| Metric | Value |
|---|---|
| Algorithm | RecurrentPPO (SB3 + sb3-contrib) |
| Total timesteps | {total_steps_run:,} |
| Episodes completed | {len(ep)} |
| First-5 mean reward | {f5:.4f} |
| Last-5 mean reward | {l5:.4f} |
| Improvement | {l5 - f5:+.4f} |

![Reward Curve](reward_curve.png)

## Load
```python
from sb3_contrib import RecurrentPPO
from huggingface_hub import hf_hub_download
model = RecurrentPPO.load(hf_hub_download("{HF_REPO}", "spindleflow_model.zip"))
```
"""

readme_path = "/content/README_model.md"
# Explicit UTF-8: the model card contains non-ASCII characters (em dash).
with open(readme_path, "w", encoding="utf-8") as f:
    f.write(readme_text)

# (local path, path inside the Hub repo)
candidates = [
    ("/content/spindleflow_colab_demo.zip", "spindleflow_model.zip"),
    ("/content/vec_normalize_colab.pkl", "vec_normalize.pkl"),
    ("/content/reward_curve.png", "reward_curve.png"),
    ("/content/demo/assets/reward_curve.json", "reward_curve.json"),
    (readme_path, "README.md"),
]

# Upload is best-effort per artifact, but say so out loud: the README links
# to the model zip and the reward curve, so a silent skip would leave the
# model card pointing at files that are not in the repo.
for src, dst in candidates:
    if not os.path.exists(src):
        print(f"WARNING: {src} not found; {dst} will not be uploaded.")

ops = [
    CommitOperationAdd(path_in_repo=dst, path_or_fileobj=src)
    for src, dst in candidates
    if os.path.exists(src)
]

# Push every artifact in a single commit.
api.create_commit(
    repo_id=HF_REPO,
    repo_type="model",
    operations=ops,
    commit_message="Add trained SpindleFlow RL policy (Colab T4)",
    token=HF_TOKEN,
)

print(f"Uploaded {len(ops)} files.")
print(f"Model live at: https://huggingface.co/{HF_REPO}")
print(f"First-5 mean reward : {f5:.4f}")
print(f"Last-5 mean reward : {l5:.4f}")
print(f"Improvement : {l5 - f5:+.4f}")