# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "runpod>=1.6.0",
#     "click>=8.0.0",
# ]
# ///
"""
Simple GPU training script for Bamboo-1 using RunPod.

Usage:
    export RUNPOD_API_KEY="your-api-key"
    uv run scripts/train_gpu.py
    uv run scripts/train_gpu.py --gpu "NVIDIA RTX 3090"
    uv run scripts/train_gpu.py --feat bert --max-epochs 50
"""

import os

import click
import runpod


@click.command()
@click.option("--gpu", default="NVIDIA RTX A4000", help="GPU type")
@click.option("--feat", type=click.Choice(["char", "bert"]), default="char", help="Feature type")
@click.option("--max-epochs", default=100, type=int, help="Max training epochs")
@click.option("--batch-size", default=5000, type=int, help="Tokens per batch")
@click.option("--name", default="bamboo-1-train", help="Pod name")
def main(gpu, feat, max_epochs, batch_size, name):
    """Launch Bamboo-1 training on RunPod GPU."""
    api_key = os.environ.get("RUNPOD_API_KEY")
    if not api_key:
        raise click.ClickException(
            "Set RUNPOD_API_KEY environment variable.\n"
            "Get your key at: https://runpod.io/console/user/settings"
        )

    runpod.api_key = api_key

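    # Build the pod start command as a single shell one-liner to avoid string escaping issues.
    # It is wrapped in `bash -c` because the chain uses a pipe, `&&`, and `source`
    # (assumption: docker_args is not passed through a shell by default).
    shell_cmd = (
        "curl -LsSf https://astral.sh/uv/install.sh | sh && "
        "source $HOME/.local/bin/env && "
        "git clone https://huggingface.co/undertheseanlp/bamboo-1 && "
        "cd bamboo-1 && "
        "uv sync && "
        f"uv run scripts/train.py --output models/bamboo-1 --feat {feat} "
        f"--max-epochs {max_epochs} --batch-size {batch_size}"
    )
    # No interpolated value can contain a single quote, so plain single quotes suffice.
    train_cmd = f"bash -c '{shell_cmd}'"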

    click.echo("Launching RunPod training...")
    click.echo(f"  GPU: {gpu}")
    click.echo(f"  Feature: {feat}")
    click.echo(f"  Epochs: {max_epochs}")
    click.echo(f"  Batch size: {batch_size}")

    pod = runpod.create_pod(
        name=name,
        image_name="runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04",
        gpu_type_id=gpu,
        volume_in_gb=20,
        docker_args=train_cmd,
    )

    click.echo("\nPod launched!")
    click.echo(f"  ID: {pod['id']}")
    click.echo("  Monitor: https://runpod.io/console/pods")
    click.echo(f"\nTo stop: uv run scripts/runpod_setup.py terminate {pod['id']}")


if __name__ == "__main__":
    main()