"""
train.py — GRPO Training Script for Planetary Rover Navigation
================================================================

Uses Unsloth's FastLanguageModel + TRL's GRPOTrainer to fine-tune
meta-llama/Llama-3.2-1B-Instruct for autonomous rover navigation.

Hardware target : NVIDIA RTX 3050 — strict 6 GB VRAM limit
                  (note: main() currently logs and checks a 24 GB cloud-GPU budget)
Quantisation    : 4-bit NF4 via Unsloth
LoRA            : rank 16, attention + MLP projections
GRPO group size : 8 generations per prompt by default; set
                  ROVER_NUM_GENERATIONS lower (e.g. 4) to avoid OOM

Reward functions
----------------
1. Format Gatekeeper  — validates <action>JSON</action> structure
2. Environment Reward — POSTs parsed action to local physics server

Prerequisites
-------------
1. Local server running:
       uvicorn main:app --host 0.0.0.0 --port 7860
2. Python packages:
       pip install unsloth trl datasets peft accelerate wandb requests
"""
|
|
from __future__ import annotations

import json
import logging
import math
import os
import random
import re
import sys
import time
from numbers import Real
from typing import Any

import requests
import torch
import wandb
from datasets import Dataset

# Keep unsloth ahead of trl/transformers, as in the original import order
# (Unsloth asks to be imported first so that its patches apply).
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer
from transformers import TrainerCallback
|
|
| |
| |
| |
|
|
# --- Model, environment server and output location -------------------------
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
SERVER_URL = os.getenv("ROVER_SERVER_URL", "http://127.0.0.1:7860")
OUTPUT_DIR = "./grpo_rover_checkpoints"
SEED = 42

# --- Sequence budget and LoRA adapter shape --------------------------------
MAX_SEQ_LENGTH = 512
LORA_RANK = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.0

# --- GRPO / optimiser hyper-parameters -------------------------------------
# The ROVER_* environment variables override the defaults without code edits.
NUM_TRAIN_EPISODES = 150  # seeds reset per task when building the dataset
MAX_PROMPT_LENGTH = 256
MAX_COMPLETION_LENGTH = 256
NUM_GENERATIONS = int(os.getenv("ROVER_NUM_GENERATIONS", "8"))  # completions per prompt
LEARNING_RATE = 1e-6
KL_COEF = 0.04  # forwarded to GRPOConfig as `beta`
NUM_TRAIN_EPOCHS = 2
PER_DEVICE_BATCH = 1
GRAD_ACCUM_STEPS = int(os.getenv("ROVER_GRAD_ACCUM_STEPS", "8"))
WARMUP_STEPS = int(os.getenv("ROVER_WARMUP_STEPS", "10"))
USE_BF16 = os.getenv("ROVER_USE_BF16", "0") == "1"  # fp16 is used otherwise

# --- Format-reward shaping --------------------------------------------------
FORMAT_REWARD_GOOD = 1.0
FORMAT_REWARD_BAD = 0.0
VERBOSITY_THRESHOLD = 80  # whitespace-separated words allowed before the penalty starts
VERBOSITY_PENALTY_K = 200  # extra words over which the format reward decays to 0

# --- Logging ----------------------------------------------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("train")
|
|
|
|
def _extract_scalar_reward(logs: dict[str, Any]) -> float | None:
    """Return one scalar reward value from a TRL/Trainer log payload.

    The headline ``"reward"`` entry is returned when it is numeric.
    Otherwise every numeric entry whose key contains ``"reward"`` is
    averaged, skipping dispersion/extreme statistics (std, min, max, var).

    Args:
        logs: Mapping of metric name to logged value.

    Returns:
        The reward as a float, or ``None`` when no usable entry exists.
    """
    # Statistics are recognised as whole key tokens ("reward_std",
    # "rewards/foo/max"), not substrings: a substring test would also drop
    # unrelated keys such as "terminal_reward" (it contains "min").
    stat_tokens = {"std", "stddev", "min", "max", "var", "variance"}

    def _is_number(value: Any) -> bool:
        # bool is a numbers.Real, but a flag is not a reward value.
        return isinstance(value, Real) and not isinstance(value, bool)

    headline = logs.get("reward")
    if _is_number(headline):
        return float(headline)

    reward_terms: list[float] = []
    for key, value in logs.items():
        key_lower = key.lower()
        if "reward" not in key_lower or not _is_number(value):
            continue
        if stat_tokens.intersection(re.split(r"[^a-z0-9]+", key_lower)):
            continue
        reward_terms.append(float(value))

    if not reward_terms:
        return None
    return sum(reward_terms) / len(reward_terms)
|
|
|
|
class CompactMetricsCallback(TrainerCallback):
    """Emit a concise log line to simplify screenshot capture in Spaces logs."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log one ``METRICS {...}`` line per trainer log event that carries a loss.

        Events without a numeric ``"loss"`` entry are ignored. ``reward`` and
        ``lr`` are included only when present in ``logs``.

        Returns:
            The ``control`` object, unchanged.
        """
        if not logs:
            return control

        loss = logs.get("loss")
        if not isinstance(loss, Real):
            return control

        metrics: dict[str, float] = {"loss": round(float(loss), 6)}

        reward = _extract_scalar_reward(logs)
        if reward is not None:
            metrics["reward"] = round(reward, 6)

        learning_rate = logs.get("learning_rate")
        if isinstance(learning_rate, Real):
            # Six significant digits rather than six decimals: with a peak
            # learning rate of 1e-6, round(lr, 6) would print most values as 0.0.
            metrics["lr"] = float(f"{float(learning_rate):.6g}")

        log.info("METRICS %s", metrics)
        return control
|
|
|
|
| |
| |
| |
|
|
# System message sent with every training prompt. The backslashes after the
# opening quotes and at the end of the last line keep the string free of a
# leading/trailing newline.
SYSTEM_PROMPT = """\
You are a planetary rover navigation controller.
Respond ONLY with your action inside <action></action> tags as valid JSON.

Action schema:
{"thrust": float[0,1], "steering": float[-1,1], "brake": 0|1, "vertical_thruster": float[-0.2,0.2]}

Key physics:
- heading_error = atan2(target_dy, target_dx) - rover_heading
- steering ≈ clamp(heading_error * 2.5, -1, 1)
- thrust=1.0 for progress; brake=0 unless overshooting
- If nearest_obstacle < 10m, steer perpendicular to dodge\
"""
|
|
|
|
| |
| |
| |
|
|
def build_compact_prompt(
    task_id: str,
    obs: dict[str, Any],
    step_num: int,
    max_steps: int,
) -> str:
    """Build a token-efficient user prompt from an observation dict.

    The prompt is meant to stay short (roughly 100-120 tokens) so that the
    system and user messages together fit the 256-token prompt budget.

    Args:
        task_id: Task label echoed in the prompt header.
        obs: Observation with the keys ``target_relative`` (``x``/``y``),
            ``rover_heading``, ``target_distance``, ``battery_level``,
            ``nearest_obstacle_distance`` and ``terrain_type``.
        step_num: Current step, shown as ``STEP: step_num/max_steps``.
        max_steps: Step budget of the task.

    Returns:
        The formatted multi-line prompt.

    Raises:
        KeyError: If ``obs`` lacks one of the keys above.
        ValueError: If the heading error is infinite.
    """
    target = obs["target_relative"]
    heading_error = math.atan2(target["y"], target["x"]) - obs["rover_heading"]

    # An infinite error can never be wrapped into range; the former
    # "add/subtract 2*pi until in range" loops would spin forever on it.
    if math.isinf(heading_error):
        raise ValueError(f"non-finite rover_heading in observation: {obs['rover_heading']!r}")

    # Wrap into (-pi, pi] in constant time. math.remainder is exact and
    # returns a value in [-pi, pi]; only the -pi endpoint needs folding.
    heading_error = math.remainder(heading_error, 2 * math.pi)
    if heading_error <= -math.pi:
        heading_error += 2 * math.pi

    # Same proportional gain that the system prompt describes to the model.
    suggested_steering = max(-1.0, min(1.0, heading_error * 2.5))

    return (
        f"TASK: {task_id} STEP: {step_num}/{max_steps}\n"
        f"target_distance={obs['target_distance']:.1f}m "
        f"heading_error={heading_error:.4f}rad\n"
        f"battery={obs['battery_level']:.3f} "
        f"nearest_obstacle={obs['nearest_obstacle_distance']:.1f}m "
        f"terrain={obs['terrain_type']}\n"
        f"suggested_steering={suggested_steering:.4f}\n"
        "Output your <action> JSON now."
    )
|
|
|
|
| |
| |
| |
|
|
# Step budget per task id; rendered as "STEP: n/max" in the user prompt.
TASK_MAX_STEPS: dict[str, int] = {"easy": 200, "medium": 300, "hard": 100}
|
|
|
|
def _check_server() -> None:
    """Fail fast if the environment server is unreachable.

    Sends ``GET {SERVER_URL}/tasks``. On any request failure or non-2xx
    status the reason is logged and the process exits with status 1.
    """
    try:
        response = requests.get(f"{SERVER_URL}/tasks", timeout=5)
        response.raise_for_status()
    except requests.RequestException as exc:
        # Include the underlying error: "connection refused" and "HTTP 500"
        # need very different fixes.
        log.error(
            "Cannot reach environment server at %s (%s) — "
            "start it with: uvicorn main:app --host 0.0.0.0 --port 7860",
            SERVER_URL,
            exc,
        )
        sys.exit(1)

    log.info("Environment server is live at %s", SERVER_URL)
|
|
|
|
def generate_training_dataset(n_per_task: int = NUM_TRAIN_EPISODES) -> Dataset:
    """Generate a training dataset by resetting episodes across all tasks.

    For every task in ``TASK_MAX_STEPS`` and every seed in
    ``range(n_per_task)`` the server's ``/reset`` endpoint is called and the
    returned observation is turned into a first-step prompt. Episodes whose
    reset fails or whose payload cannot be rendered are skipped with a warning.

    Each row contains:
        prompt   chat-formatted messages (system + user)
        task_id  for environment reward replay
        seed     for environment reward replay

    Args:
        n_per_task: Number of seeds (0..n_per_task-1) to reset per task.

    Returns:
        A shuffled ``datasets.Dataset`` of prompt rows.

    Raises:
        RuntimeError: If not a single episode could be generated.
    """
    rows: list[dict[str, Any]] = []

    for task_id, max_steps in TASK_MAX_STEPS.items():
        for seed in range(n_per_task):
            try:
                resp = requests.post(
                    f"{SERVER_URL}/reset",
                    json={"task_id": task_id, "seed": seed},
                    timeout=10,
                )
                resp.raise_for_status()
                # Payload access and prompt rendering stay inside the try so a
                # single malformed observation skips one episode instead of
                # aborting the whole generation run.
                obs = resp.json()["obs"]
                user_msg = build_compact_prompt(task_id, obs, step_num=1, max_steps=max_steps)
            except (requests.RequestException, KeyError, TypeError, ValueError) as exc:
                log.warning("Skipping episode (task=%s seed=%d): %s", task_id, seed, exc)
                continue

            rows.append({
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
                "task_id": task_id,
                "seed": seed,
            })

    if not rows:
        raise RuntimeError(
            f"No training prompts could be generated from {SERVER_URL}; "
            "check the environment server logs."
        )

    # Seeded shuffle so the prompt order is reproducible across runs.
    random.Random(SEED).shuffle(rows)
    log.info(
        "Generated %d training prompts (%d per task × %d tasks)",
        len(rows), n_per_task, len(TASK_MAX_STEPS),
    )
    return Dataset.from_list(rows)
|
|
|
|
| |
| |
| |
|
|
| |
# First <action>...</action> span in a completion; DOTALL lets the JSON
# payload span several lines, and surrounding whitespace is trimmed.
_ACTION_RE = re.compile(r"<action>\s*(.*?)\s*</action>", re.DOTALL)

# Required action fields mapped to their inclusive (low, high) bounds.
_ACTION_FIELDS = {
    "thrust": (0.0, 1.0),
    "steering": (-1.0, 1.0),
    "brake": (0, 1),
    "vertical_thruster": (-0.2, 0.2),
}
|
|
|
|
def _completion_to_text(completion: Any) -> str:
    """Convert TRL completion payloads (str/bytes/dict/list) into plain text.

    Args:
        completion: ``None``, text, raw bytes, a message-like dict (the first
            of ``content``/``text``/``completion``/``generated_text`` that is
            present is used), or a list/tuple of any of these.

    Returns:
        The extracted text. Sequence items are converted recursively and the
        non-empty results joined with newlines. Anything unrecognised falls
        back to ``str(completion)``.
    """
    if completion is None:
        return ""

    if isinstance(completion, str):
        return completion

    if isinstance(completion, (bytes, bytearray)):
        return completion.decode("utf-8", errors="ignore")

    if isinstance(completion, dict):
        for key in ("content", "text", "completion", "generated_text"):
            if key in completion:
                return _completion_to_text(completion[key])
        return str(completion)

    if isinstance(completion, (list, tuple)):
        parts = (_completion_to_text(item) for item in completion)
        return "\n".join(part for part in parts if part)

    return str(completion)
|
|
|
|
def parse_action_from_completion(completion: Any) -> dict[str, Any] | None:
    """Extract and validate an action JSON from <action>...</action> tags.

    Only the first ``<action>`` span is considered. Every field listed in
    ``_ACTION_FIELDS`` must be present and numeric. Values up to 0.5 outside
    a field's bounds are tolerated and clamped; anything further out, or
    non-finite, invalidates the action. ``brake`` is rounded to an int.

    Args:
        completion: Raw completion payload in any shape accepted by
            ``_completion_to_text``.

    Returns:
        The validated, clamped action dict, or ``None`` if the completion
        does not contain a valid action.
    """
    text = _completion_to_text(completion)
    if not text:
        return None

    match = _ACTION_RE.search(text)
    if match is None:
        return None

    try:
        parsed = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None

    if not isinstance(parsed, dict):
        return None

    action: dict[str, Any] = {}
    for field, (lo, hi) in _ACTION_FIELDS.items():
        if field not in parsed:
            return None

        try:
            val = float(parsed[field])
        except (TypeError, ValueError, OverflowError):
            # OverflowError: JSON integers too large for a float.
            return None

        # json.loads accepts NaN/Infinity and maps 1e999 to inf. NaN would
        # slip through the range check and be clamped to `hi`; inf would
        # make round() raise. Model output is untrusted, so reject both.
        if not math.isfinite(val):
            return None

        if field == "brake":
            val = int(round(val))

        if val < lo - 0.5 or val > hi + 0.5:
            return None

        action[field] = max(lo, min(hi, val))

    return action
|
|
|
|
def format_reward_fn(completions: list[Any], **kwargs) -> list[float]:
    """Reward Function 1: the Format Gatekeeper.

    Scores ``FORMAT_REWARD_GOOD`` when a completion contains a valid
    <action>JSON</action> matching the rover action schema and
    ``FORMAT_REWARD_BAD`` otherwise.

    A soft verbosity penalty applies to valid completions. Length is
    estimated as the number of whitespace-separated words (a cheap stand-in
    for tokens). Beyond ``VERBOSITY_THRESHOLD`` words the reward decays
    linearly, reaching 0 at ``VERBOSITY_THRESHOLD + VERBOSITY_PENALTY_K``.

    Args:
        completions: Generated completions in any shape accepted by
            ``_completion_to_text``.
        **kwargs: Extra dataset columns passed by the trainer; unused.

    Returns:
        One reward per completion, in order.
    """
    rewards: list[float] = []

    for completion in completions:
        text = _completion_to_text(completion)
        if parse_action_from_completion(text) is None:
            rewards.append(FORMAT_REWARD_BAD)
            continue

        reward = FORMAT_REWARD_GOOD
        excess_words = len(text.split()) - VERBOSITY_THRESHOLD
        if excess_words > 0:
            reward *= max(0.0, 1.0 - excess_words / VERBOSITY_PENALTY_K)
        rewards.append(reward)

    return rewards
|
|
|
|
| |
| |
| |
|
|
def environment_reward_fn(completions: list[Any], **kwargs) -> list[float]:
    """Reward Function 2: the Environment.

    For each completion:
      1. Parse the action from <action> tags.
      2. Reset a fresh episode with the same (task_id, seed) as the prompt.
      3. POST the action to /step.
      4. Use the scalar ``reward`` field of the step response.

    If parsing or the HTTP round-trip fails, or the server returns a
    non-finite reward, the completion scores 0.0 (neutral, no signal).

    Args:
        completions: Generated completions, one per sampled generation.
        **kwargs: Dataset columns forwarded by the trainer. ``task_id`` and
            ``seed`` are read as lists aligned with ``completions``; missing
            entries fall back to ``"easy"`` / ``0``.

    Returns:
        One reward per completion, in order.
    """
    task_ids: list[str] = kwargs.get("task_id") or []
    seeds: list[int] = kwargs.get("seed") or []

    rewards: list[float] = []

    # One session for the whole batch so the reset/step pairs reuse a single
    # connection instead of opening two new ones per completion.
    with requests.Session() as session:
        for i, completion in enumerate(completions):
            action = parse_action_from_completion(completion)
            if action is None:
                rewards.append(0.0)
                continue

            task_id = task_ids[i] if i < len(task_ids) else "easy"
            seed = seeds[i] if i < len(seeds) else 0

            try:
                reset_resp = session.post(
                    f"{SERVER_URL}/reset",
                    json={"task_id": task_id, "seed": seed},
                    timeout=10,
                )
                reset_resp.raise_for_status()
                episode_id = reset_resp.json()["episode_id"]

                step_resp = session.post(
                    f"{SERVER_URL}/step",
                    json=action,
                    params={"episode_id": episode_id},
                    timeout=10,
                )
                step_resp.raise_for_status()

                reward = float(step_resp.json().get("reward", 0.0))
                if not math.isfinite(reward):
                    # A NaN/inf reward would poison the group advantages.
                    raise ValueError(f"non-finite reward from server: {reward!r}")
            except Exception as exc:
                # Deliberately broad: a flaky environment call must not abort
                # a long training run; the failure is logged and scored 0.
                # %s for the seed so an unexpected seed type cannot break the
                # log call itself.
                log.warning("Environment reward failed (task=%s seed=%s): %s", task_id, seed, exc)
                reward = 0.0

            rewards.append(reward)

    return rewards
|
|
|
|
| |
| |
| |
|
|
def load_model():
    """Load the base model in 4-bit via Unsloth and attach LoRA adapters.

    Loads ``MODEL_NAME`` with ``load_in_4bit=True`` (compute dtype bf16 when
    ``USE_BF16`` is set, fp16 otherwise), then wraps it with LoRA adapters on
    the attention and MLP projection modules.

    Returns:
        ``(model, tokenizer)``: the PEFT-wrapped model and its tokenizer,
        configured for left padding with a pad token set.
    """
    log.info("Loading %s with 4-bit NF4 quantisation via Unsloth…", MODEL_NAME)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=torch.bfloat16 if USE_BF16 else torch.float16,
        load_in_4bit=True,
    )

    log.info("Attaching LoRA (rank=%d, alpha=%d) to attention + MLP…", LORA_RANK, LORA_ALPHA)

    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=[
            # attention projections
            "q_proj", "k_proj", "v_proj", "o_proj",
            # MLP projections
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=SEED,
    )

    # Reuse EOS as the pad token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    vram_gb = torch.cuda.memory_allocated() / 1e9
    log.info("Model loaded. Current VRAM: %.2f GB", vram_gb)

    return model, tokenizer
|
|
|
|
| |
| |
| |
|
|
def build_training_config() -> GRPOConfig:
    """Build the GRPOConfig from the module-level hyper-parameter constants.

    Mixed precision follows ``USE_BF16``: bf16 when set, fp16 otherwise.
    Metrics are logged every step and reported to Weights & Biases.

    Returns:
        The configured ``GRPOConfig``.
    """
    use_bf16 = USE_BF16

    return GRPOConfig(
        output_dir=OUTPUT_DIR,

        # GRPO sampling
        num_generations=NUM_GENERATIONS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,
        beta=KL_COEF,

        # optimiser / schedule
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_steps=WARMUP_STEPS,
        max_grad_norm=1.0,

        # batching
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        num_train_epochs=NUM_TRAIN_EPOCHS,

        # precision: exactly one of bf16/fp16 is enabled
        bf16=use_bf16,
        fp16=not use_bf16,

        # logging / checkpoints
        logging_steps=1,
        save_steps=50,
        save_total_limit=3,
        report_to="wandb",
        run_name="openenv-rover-run",
        seed=SEED,

        # Keep the task_id / seed columns: environment_reward_fn reads them
        # from the kwargs it is called with.
        remove_unused_columns=False,
    )
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """Run the full GRPO pipeline.

    Checks the environment server, loads the model, builds the prompt dataset
    and trainer config, trains, saves the final LoRA adapter under
    ``OUTPUT_DIR/final_adapter`` and reports peak VRAM.
    """
    vram_budget_gb = 24.0  # budget that the peak-VRAM report is checked against

    log.info("=" * 60)
    log.info("GRPO Training — Planetary Rover Navigation")
    log.info("Model : %s", MODEL_NAME)
    log.info("VRAM : 24 GB+ cloud GPU (4-bit NF4, LoRA r=%d, group=%d)",
             LORA_RANK, NUM_GENERATIONS)
    log.info("Precision: %s", "bf16" if USE_BF16 else "fp16")
    log.info("=" * 60)

    # 1. Fail fast when the physics server is down.
    _check_server()

    # 2. Model + tokenizer.
    model, tokenizer = load_model()

    # 3. Prompts from the physics engine.
    log.info("Generating full training dataset from physics engine...")
    train_dataset = generate_training_dataset()

    # 4. Trainer configuration.
    config = build_training_config()

    # 5. Trainer. TRL's GRPOTrainer takes the tokenizer as `processing_class`;
    #    its documented signature has no `tokenizer` parameter.
    log.info("Initialising GRPOTrainer with 2 reward functions…")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[format_reward_fn, environment_reward_fn],
        args=config,
        train_dataset=train_dataset,
    )
    trainer.add_callback(CompactMetricsCallback())

    # 6. Train.
    log.info("Starting GRPO training…")
    start = time.time()
    trainer.train()
    elapsed = time.time() - start
    log.info("Training complete in %.1f minutes.", elapsed / 60)

    # 7. Persist the adapter and tokenizer.
    final_path = os.path.join(OUTPUT_DIR, "final_adapter")
    model.save_pretrained(final_path)
    tokenizer.save_pretrained(final_path)
    log.info("Final LoRA adapter saved to %s", final_path)

    # 8. VRAM report.
    peak_vram = torch.cuda.max_memory_allocated() / 1e9
    log.info("Peak VRAM usage: %.2f GB", peak_vram)
    if peak_vram > vram_budget_gb:
        log.warning("⚠ Peak VRAM exceeded 24 GB! Reduce NUM_GENERATIONS or LORA_RANK.")
    else:
        log.info("✅ VRAM within cloud GPU budget.")


if __name__ == "__main__":
    main()
|
|