sakha / scripts /modal_quick_train.py
atharva-again's picture
chore: pre-merge cleanup for colab-training branch
509d302 unverified
"""
Quick GRPO training run for Sakha.
Runs small episode count and writes metrics to JSONL for plotting.
Usage:
modal run scripts/modal_quick_train.py --episodes 10 --max-steps 12
"""
from pathlib import Path

import modal
# Repository root on the machine that launches `modal run`. This script lives
# in <repo>/scripts/, so the root is two levels above this file. Resolving it
# from __file__ (rather than a developer-specific absolute path) lets any
# checkout build the image.
_REPO_ROOT = Path(__file__).resolve().parent.parent

# Container image used by the training function: Debian slim + Python 3.12,
# the training dependencies installed into the system interpreter via uv, and
# the local repository made available at /sakha.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git")  # needed for the git+https OpenEnv install below
    .pip_install("uv")
    .run_commands(
        "uv pip install --system 'transformers>=5.2.0' 'trl[quantization]' peft accelerate bitsandbytes jmespath",
        "uv pip install --system 'git+https://github.com/meta-pytorch/OpenEnv.git' pydantic fastapi uvicorn openai python-dotenv tenacity",
    )
    .add_local_dir(
        str(_REPO_ROOT),
        remote_path="/sakha",
        # Keep VCS data, virtualenvs, caches and local outputs out of the image.
        ignore=[
            ".git",
            ".venv",
            "__pycache__",
            ".pytest_cache",
            ".ruff_cache",
            "artifacts",
            ".sisyphus",
            "*.pyc",
            "*.pyo",
            "output.md",
        ],
    )
)
# Named Modal volume "sakha-training", created on first use. The training
# function mounts it at /artifacts and writes its outputs there.
vol = modal.Volume.from_name("sakha-training", create_if_missing=True)
# Modal app that owns the functions below; constructed with the image defined
# above.
app = modal.App("sakha-grpo-quick", image=image)
def _load_latest_results(output_dir: str, not_before: float) -> object:
    """Return the parsed content of the newest fresh ``results.json``.

    Searches ``output_dir`` recursively and only considers files whose
    modification time is at or after ``not_before`` (seconds since the epoch).
    The output directory lives on a persistent volume, so without this filter
    a run that produced no results would report the results of an earlier run.

    Returns an empty dict when no fresh file exists or the newest one cannot
    be read or parsed.
    """
    import json
    from pathlib import Path

    fresh = []
    for path in Path(output_dir).rglob("results.json"):
        try:
            mtime = path.stat().st_mtime
        except OSError:
            # File vanished or is unreadable; ignore it.
            continue
        if mtime >= not_before:
            fresh.append((mtime, path))

    if not fresh:
        return {}

    _, newest = max(fresh, key=lambda item: item[0])
    try:
        return json.loads(newest.read_text())
    except (OSError, json.JSONDecodeError) as exc:
        # A crashed run may leave a truncated file; report it instead of
        # raising so the caller still gets the success flag.
        print(f"Could not read results from {newest}: {exc}")
        return {}


@app.function(gpu="T4", timeout=1800, volumes={"/artifacts": vol})
def train(episodes: int = 10, max_steps: int = 12, model: str = "Qwen/Qwen3-0.6B") -> str:
    """Run ``scripts/train_grpo.py`` in smoke mode and summarize the outcome.

    Args:
        episodes: Value passed to the training script's ``--episodes`` flag.
        max_steps: Value passed to the training script's ``--max-steps`` flag.
        model: Model identifier passed to ``--model``.

    Returns:
        A JSON string with two keys: ``success`` (whether the training
        subprocess exited with code 0) and ``results`` (content of the newest
        ``results.json`` written during this run, or ``{}`` if there is none).
    """
    import json
    import os
    import subprocess
    import sys
    import time

    os.environ["PYTHONPATH"] = "/sakha/src"
    os.environ["WANDB_MODE"] = "disabled"
    os.environ["TRL_EXPERIMENTAL_SILENCE"] = "1"
    sys.path.insert(0, "/sakha/src")

    print("Installing sakha package...")
    install = subprocess.run(
        ["pip", "install", "--no-deps", "-e", "/sakha"], stdout=sys.stdout, stderr=sys.stderr
    )
    if install.returncode != 0:
        # Best effort: PYTHONPATH already points at /sakha/src, so keep going
        # but make the failure visible in the logs.
        print(f"WARNING: installing sakha failed with exit code {install.returncode}")

    output_dir = "/artifacts/quick_train"
    cmd = [
        "python",
        "/sakha/scripts/train_grpo.py",
        "--mode",
        "smoke",
        "--task",
        "hard",
        "--model",
        model,
        "--episodes",
        str(episodes),
        "--max-steps",
        str(max_steps),
        "--output-dir",
        output_dir,
        "--report-to",
        "none",
    ]
    print(f"Running training command: {' '.join(cmd)}")

    # Small slack so coarse filesystem timestamps cannot exclude a file that
    # was written by this run.
    started_at = time.time() - 2.0
    result = subprocess.run(cmd, stdout=sys.stdout, stderr=sys.stderr, cwd="/sakha", env=os.environ)

    results = _load_latest_results(output_dir, not_before=started_at)
    return json.dumps({"success": result.returncode == 0, "results": results})
@app.local_entrypoint()
def main(episodes: int = 10, max_steps: int = 12, model: str = "Qwen/Qwen3-0.6B"):
    """Launch the remote training function and print the JSON string it returns."""
    banner = f"Starting quick training: {episodes} episodes, {max_steps} steps"
    print(banner)
    remote_kwargs = {"episodes": episodes, "max_steps": max_steps, "model": model}
    print(train.remote(**remote_kwargs))