Spaces:
Sleeping
Sleeping
"""
Modal app for running Sakha GRPO training.

Usage:
    modal run scripts/modal_train.py --mode demo --task hard --episodes 50
    modal run scripts/modal_train.py --mode smoke
"""
from pathlib import Path

import modal

# Local checkout root. This file lives in ``<root>/scripts/``, so the project
# root is two levels up. ``.parent.parent`` (instead of ``parents[1]``) cannot
# raise even when the module is imported from a very shallow path.
LOCAL_PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Container image: git is needed for the OpenEnv git+https install, uv speeds
# up dependency installation, and the local project is copied to /sakha with
# caches, virtualenvs and previous artifacts excluded.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git")
    .pip_install("uv")
    .run_commands(
        "uv pip install --system 'transformers>=5.2.0' 'trl[quantization]' peft accelerate bitsandbytes jmespath wandb",
        "uv pip install --system 'git+https://github.com/meta-pytorch/OpenEnv.git' pydantic fastapi uvicorn openai python-dotenv tenacity",
    )
    .add_local_dir(
        str(LOCAL_PROJECT_ROOT),
        remote_path="/sakha",
        ignore=[
            ".git",
            ".venv",
            "__pycache__",
            ".pytest_cache",
            ".ruff_cache",
            "artifacts",
            ".sisyphus",
            "*.pyc",
            "*.pyo",
        ],
    )
)

# Named volume for training outputs; created on first use.
volume = modal.Volume.from_name("sakha-training", create_if_missing=True)
app = modal.App("sakha-grpo-training", image=image)
@app.function(
    # NOTE(review): the GPU type and timeout are not recoverable from this
    # listing. bitsandbytes/TRL training needs a CUDA device, and Modal's
    # default function timeout is far shorter than a training run -- confirm
    # both values against what the job actually needs.
    gpu="A10G",
    timeout=4 * 60 * 60,
    # Mount the persistent volume at the directory the trainer writes to
    # (/artifacts/grpo) so results and checkpoints outlive the container.
    volumes={"/artifacts": volume},
)
def run_training(
    mode: str = "demo",
    task: str = "hard",
    episodes: int | None = None,
    max_steps: int = 96,
    model: str = "Qwen/Qwen3-0.6B",
    seed: int = 42,
) -> dict:
    """Run GRPO training on Modal.

    Installs the mounted ``/sakha`` project (best effort), runs
    ``/sakha/scripts/train_grpo.py`` as a subprocess, and collects whatever it
    wrote under ``/artifacts/grpo``.

    Args:
        mode: Forwarded as ``--mode``.
        task: Forwarded as ``--task``.
        episodes: Forwarded as ``--episodes`` only when not ``None``.
        max_steps: Forwarded as ``--max-steps``.
        model: Forwarded as ``--model``.
        seed: Forwarded as ``--seed``.

    Returns:
        A dict with the training subprocess's ``stdout``, ``stderr``,
        ``returncode`` and ``success``. Optional keys:
        ``install_error`` (pip stderr when the editable install failed),
        ``results_file``/``results`` (newest ``results.json`` found),
        ``results_error`` (that file could not be read/parsed), and
        ``checkpoints`` (sorted ``checkpoint-*`` paths under the output dir).
    """
    import json
    import os
    import subprocess
    import sys
    from pathlib import Path

    project_root = "/sakha"
    output_dir = Path("/artifacts/grpo")
    # Use the interpreter running this function for both pip and training so
    # the editable install lands where the trainer will import it from.
    python = sys.executable

    # Subprocess environment: the project's src/ is importable even if the
    # editable install below fails; TRL warnings and W&B logging are disabled.
    env = {
        **os.environ,
        "PYTHONPATH": f"{project_root}/src",
        "TRL_EXPERIMENTAL_SILENCE": "1",
        "WANDB_MODE": "disabled",
    }

    # Best-effort editable install: try without dependencies first (they are
    # baked into the image), then fall back to letting pip resolve them.
    install = subprocess.run(
        [python, "-m", "pip", "install", "--no-deps", "-e", project_root],
        capture_output=True,
        text=True,
        env=env,
    )
    if install.returncode != 0:
        install = subprocess.run(
            [python, "-m", "pip", "install", "-e", project_root],
            capture_output=True,
            text=True,
            env=env,
        )

    cmd = [
        python,
        f"{project_root}/scripts/train_grpo.py",
        "--mode",
        mode,
        "--task",
        task,
        "--model",
        model,
        "--max-steps",
        str(max_steps),
        "--seed",
        str(seed),
        "--output-dir",
        str(output_dir),
    ]
    if episodes is not None:
        cmd.extend(["--episodes", str(episodes)])

    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        cwd=project_root,
        env=env,
    )

    output: dict = {
        "stdout": result.stdout,
        "stderr": result.stderr,
        "returncode": result.returncode,
        "success": result.returncode == 0,
    }
    # Previously a failed install was silently ignored; surface it so a
    # training failure caused by a broken install is diagnosable.
    if install.returncode != 0:
        output["install_error"] = install.stderr

    # The output dir lives on a persistent volume, so several runs may have
    # written results.json; report the most recently modified one.
    results_files = list(output_dir.rglob("results.json"))
    if results_files:
        latest = max(results_files, key=lambda p: p.stat().st_mtime)
        output["results_file"] = str(latest)
        try:
            output["results"] = json.loads(latest.read_text())
        except (OSError, json.JSONDecodeError) as exc:
            # A truncated/corrupt file must not discard the captured logs.
            output["results_error"] = f"{type(exc).__name__}: {exc}"

    checkpoints = sorted(str(c) for c in output_dir.rglob("checkpoint-*"))
    if checkpoints:
        output["checkpoints"] = checkpoints

    # Persist everything written to /artifacts before the container exits.
    volume.commit()
    return output
@app.local_entrypoint()
def main(
    mode: str = "demo",
    task: str = "hard",
    episodes: int | None = None,
    max_steps: int = 96,
    model: str = "Qwen/Qwen3-0.6B",
    seed: int = 42,
):
    """Launch ``run_training`` remotely and print its report locally.

    Registered as the ``modal run`` entrypoint so the CLI flags shown in the
    module docstring (``--mode``, ``--task``, ``--episodes``, ...) map onto
    these parameters, which are forwarded unchanged to ``run_training``.

    Raises:
        RuntimeError: if the remote training subprocess exited non-zero.
    """
    print(f"Running Sakha GRPO training: mode={mode}, task={task}, model={model}")
    result = run_training.remote(
        mode=mode,
        task=task,
        episodes=episodes,
        max_steps=max_steps,
        model=model,
        seed=seed,
    )

    print(f"\nExit code: {result['returncode']}")
    print(f"Success: {result['success']}")
    print("\n--- STDOUT ---")
    print(result["stdout"])
    if result["stderr"]:
        print("\n--- STDERR ---")
        print(result["stderr"])
    # "results" is only present when the trainer wrote a results.json.
    if "results" in result:
        print("\n--- RESULTS ---")
        print(result["results"])

    # Fail the local `modal run` invocation (non-zero exit) on remote failure.
    if not result["success"]:
        raise RuntimeError("Training failed!")
    print("\nTraining completed!")