|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Simple GPU training script for Bamboo-1 using RunPod. |
|
|
|
|
|
Usage: |
|
|
export RUNPOD_API_KEY="your-api-key" |
|
|
uv run scripts/train_gpu.py |
|
|
uv run scripts/train_gpu.py --gpu "NVIDIA RTX 3090" |
|
|
uv run scripts/train_gpu.py --feat bert --max-epochs 50 |
|
|
""" |
|
|
|
|
|
import os
import shlex

import click
import runpod
|
|
|
|
|
|
|
|
@click.command()
@click.option("--gpu", default="NVIDIA RTX A4000", help="GPU type")
@click.option("--feat", type=click.Choice(["char", "bert"]), default="char", help="Feature type")
@click.option("--max-epochs", default=100, type=int, help="Max training epochs")
@click.option("--batch-size", default=5000, type=int, help="Tokens per batch")
@click.option("--name", default="bamboo-1-train", help="Pod name")
def main(gpu, feat, max_epochs, batch_size, name):
    """Launch Bamboo-1 training on RunPod GPU."""
    # NOTE: the docstring is intentionally kept to one line; click renders it
    # as the command's --help text, so implementation notes live in comments.
    #
    # Parameters (all supplied by the click options above):
    #   gpu        -- forwarded to runpod.create_pod as gpu_type_id
    #   feat       -- "char" or "bert"; forwarded to train.py as --feat
    #   max_epochs -- forwarded to train.py as --max-epochs
    #   batch_size -- forwarded to train.py as --batch-size
    #   name       -- name given to the created pod
    #
    # Raises click.ClickException when RUNPOD_API_KEY is unset or empty.

    api_key = os.environ.get("RUNPOD_API_KEY")
    if not api_key:
        raise click.ClickException(
            "Set RUNPOD_API_KEY environment variable.\n"
            "Get your key at: https://runpod.io/console/user/settings"
        )
    runpod.api_key = api_key

    # Steps executed inside the container, chained with && so that the first
    # failing step aborts the rest: install uv, load its environment, fetch
    # the project, install dependencies, then start training.
    # feat is restricted by click.Choice and the numeric options are ints, so
    # the interpolated values cannot inject extra shell syntax.
    train_cmd = " && ".join(
        [
            "curl -LsSf https://astral.sh/uv/install.sh | sh",
            "source $HOME/.local/bin/env",
            "git clone https://huggingface.co/undertheseanlp/bamboo-1",
            "cd bamboo-1",
            "uv sync",
            "uv run scripts/train.py --output models/bamboo-1"
            f" --feat {feat} --max-epochs {max_epochs} --batch-size {batch_size}",
        ]
    )

    # docker_args replaces the container's start command. The pipeline above
    # relies on shell operators (|, &&) and on the bash builtin `source`, so
    # run it explicitly through `bash -c` instead of depending on the platform
    # to interpret it with a shell. shlex.quote single-quotes the command, so
    # $HOME is expanded by the bash inside the container, not earlier.
    docker_args = f"bash -c {shlex.quote(train_cmd)}"

    click.echo("Launching RunPod training...")
    click.echo(f" GPU: {gpu}")
    click.echo(f" Feature: {feat}")
    click.echo(f" Epochs: {max_epochs}")
    click.echo(f" Batch size: {batch_size}")

    pod = runpod.create_pod(
        name=name,
        image_name="runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04",
        gpu_type_id=gpu,
        volume_in_gb=20,
        docker_args=docker_args,
    )

    # Look the id up once; it is needed for both the summary and the hint.
    pod_id = pod["id"]

    click.echo("\nPod launched!")
    click.echo(f" ID: {pod_id}")
    click.echo(" Monitor: https://runpod.io/console/pods")
    click.echo(f"\nTo stop: uv run scripts/runpod_setup.py terminate {pod_id}")
|
|
|
|
|
|
|
|
# Script entry point: run the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|