Bhaskar commited on
Commit Β·
51bb0d4
1
Parent(s): d29a040
Round 2 Upgrade: Added GRPO train.py and vector-field reward shaping
Browse files- .gitignore +3 -1
- README.md +14 -12
- main.py +132 -12
- openenv.yaml +18 -2
- train.py +523 -0
.gitignore
CHANGED
|
@@ -1,3 +1,5 @@
|
|
| 1 |
env/
|
| 2 |
__pycache__/
|
| 3 |
-
.env
|
|
|
|
|
|
|
|
|
| 1 |
env/
|
| 2 |
__pycache__/
|
| 3 |
+
.env
|
| 4 |
+
|
| 5 |
+
*.egg-info/
|
README.md
CHANGED
|
@@ -335,15 +335,16 @@ step_efficiency = 1.0 β (steps_taken / max_steps)
|
|
| 335 |
|
| 336 |
The step reward returned by `/step` is used for online RL training. It is separate from the grader score.
|
| 337 |
|
| 338 |
-
| Event | Reward |
|
| 339 |
-
|---|---|
|
| 340 |
-
| Every step | β0.01
|
| 341 |
-
| Battery drain | βdrain Γ 2.0
|
| 342 |
-
| Waypoint reached | +
|
| 343 |
-
| Obstacle collision | β5.0 |
|
| 344 |
-
| Battery depleted | β20.0 |
|
| 345 |
-
|
|
| 346 |
-
|
|
|
|
|
| 347 |
|
| 348 |
---
|
| 349 |
|
|
@@ -352,9 +353,10 @@ The step reward returned by `/step` is used for online RL training. It is separa
|
|
| 352 |
```
|
| 353 |
planetary-rover-env/
|
| 354 |
βββ openenv.yaml # Typed observation + action space declarations
|
| 355 |
-
βββ main.py # FastAPI server β physics engine + all routes (
|
| 356 |
-
βββ
|
| 357 |
-
βββ
|
|
|
|
| 358 |
βββ Dockerfile # Two-stage optimised build, port 7860, non-root user
|
| 359 |
βββ README.md # This file
|
| 360 |
```
|
|
|
|
| 335 |
|
| 336 |
The step reward returned by `/step` is used for online RL training. It is separate from the grader score.
|
| 337 |
|
| 338 |
+
| Event | Reward | Notes |
|
| 339 |
+
|---|---|---|
|
| 340 |
+
| Every step | −0.01 | Constant time-pressure penalty |
|
| 341 |
+
| Battery drain | −drain × 2.0 | Efficiency incentive |
|
| 342 |
+
| **Waypoint reached** | **+100.0** | Massive asymmetric reward to prevent early policy collapse |
|
| 343 |
+
| Obstacle collision | −5.0 | Speed zeroed, micro battery penalty |
|
| 344 |
+
| Battery depleted | −20.0 | Terminal penalty |
|
| 345 |
+
| **Potential-based distance shaping** | `(prev_dist − curr_dist) / initial_dist` | Positive when closing distance; **zero when stationary** (defeats the "stand still" exploit) |
|
| 346 |
+
| **Vector-field shaping (near obstacles)** | up to +0.3 | Active within 10 m of obstacles; rewards cosine similarity between rover heading and computed tangent vector (repulsive + attractive gradient blend) |
|
| 347 |
+
| Episode complete in < 50% of budget | +5.0 | Efficiency bonus |
|
| 348 |
|
| 349 |
---
|
| 350 |
|
|
|
|
| 353 |
```
|
| 354 |
planetary-rover-env/
|
| 355 |
├── openenv.yaml # Typed observation + action space declarations
|
| 356 |
+
├── main.py # FastAPI server — physics engine + all routes (1632 lines)
|
| 357 |
+
├── inference.py # LLM-driven inference agent (HF Inference API)
|
| 358 |
+
├── train.py # GRPO training script (Unsloth 4-bit + TRL GRPOTrainer)
|
| 359 |
+
├── requirements.txt # Pinned runtime dependencies
|
| 360 |
├── Dockerfile # Two-stage optimised build, port 7860, non-root user
|
| 361 |
└── README.md # This file
|
| 362 |
```
|
main.py
CHANGED
|
@@ -567,6 +567,9 @@ class RoverSim:
|
|
| 567 |
min_distance: float = 0.0 # running minimum; drives partial-progress score
|
| 568 |
collision_count: int = 0 # cumulative obstacle contacts
|
| 569 |
|
|
|
|
|
|
|
|
|
|
| 570 |
# -------------------------------------------------------------------
|
| 571 |
# Helpers
|
| 572 |
# -------------------------------------------------------------------
|
|
@@ -674,20 +677,129 @@ class RoverSim:
|
|
| 674 |
# Reward (called by step)
|
| 675 |
# -------------------------------------------------------------------
|
| 676 |
|
| 677 |
-
def _compute_reward(
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 684 |
wp = self.active_waypoint
|
| 685 |
if wp:
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
r += 5.0
|
|
|
|
| 691 |
return r
|
| 692 |
|
| 693 |
# -------------------------------------------------------------------
|
|
@@ -770,6 +882,10 @@ class RoverSim:
|
|
| 770 |
|
| 771 |
self.steps += 1
|
| 772 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
self._apply_kinematics(action)
|
| 774 |
drain = self._update_battery(action)
|
| 775 |
collided, nd = self._check_collision()
|
|
@@ -781,6 +897,9 @@ class RoverSim:
|
|
| 781 |
if current_dist < self.min_distance:
|
| 782 |
self.min_distance = current_dist
|
| 783 |
|
|
|
|
|
|
|
|
|
|
| 784 |
all_done = self.waypoints_hit == len(self.waypoint_list)
|
| 785 |
batt_dead = self.battery <= 0.0
|
| 786 |
self.done = all_done or batt_dead
|
|
@@ -797,7 +916,7 @@ class RoverSim:
|
|
| 797 |
else:
|
| 798 |
termination_reason = "unknown"
|
| 799 |
|
| 800 |
-
reward = self._compute_reward(wp_hit, collided, drain)
|
| 801 |
self.total_reward += reward
|
| 802 |
|
| 803 |
obs = self._build_obs()
|
|
@@ -908,6 +1027,7 @@ def _make_sim(task_id: str, seed: int | None) -> RoverSim:
|
|
| 908 |
done=False, truncated=False, waypoints_hit=0,
|
| 909 |
initial_distance=initial_dist,
|
| 910 |
min_distance=initial_dist,
|
|
|
|
| 911 |
collision_count=0,
|
| 912 |
)
|
| 913 |
|
|
|
|
| 567 |
min_distance: float = 0.0 # running minimum; drives partial-progress score
|
| 568 |
collision_count: int = 0 # cumulative obstacle contacts
|
| 569 |
|
| 570 |
+
# Reward-shaping state — tracks distance at previous step for potential-based shaping
|
| 571 |
+
_prev_distance: float = 0.0 # set equal to initial_distance at reset
|
| 572 |
+
|
| 573 |
# -------------------------------------------------------------------
|
| 574 |
# Helpers
|
| 575 |
# -------------------------------------------------------------------
|
|
|
|
| 677 |
# Reward (called by step)
|
| 678 |
# -------------------------------------------------------------------
|
| 679 |
|
| 680 |
+
def _compute_reward(
    self,
    waypoint_hit: bool,
    collided: bool,
    drain: float,
    prev_dist: float,
) -> float:
    """
    Compute the scalar reward for the step that has just been simulated.

    Components, added in this order:

    0. Constant step cost of -0.01 (time pressure).
    1. Battery cost of ``-2.0 * drain``, plus -20.0 once ``self.battery``
       is at or below zero.
    2. Collision penalty of -5.0 when ``collided`` is true.
    3. Waypoint bonus of +100.0 when ``waypoint_hit`` is true.
    4. Potential-based distance shaping with Φ(s) = −distance_to_goal:
       ``(prev_dist − curr_dist) / max(initial_distance, 1.0)``.
       A stationary rover gets exactly 0 from this term, so the step cost
       and battery drain leave it net negative.
    5. Vector-field shaping when any obstacle lies within 10 m: the
       attractive unit vector toward the goal is blended with the summed
       repulsion ``Σ (1/d² − 1/D²) · normalise(pos − obs)``, rotated by
       90° to obtain a tangent, and the reward is
       ``0.3 * |cos(heading, tangent)| * (1 − min_d / 10)``.
    6. Efficiency bonus of +5.0 when every waypoint has been hit in fewer
       than half of ``max_steps``.

    Args:
        waypoint_hit: True when this step reached a waypoint.
        collided: True when this step ended in an obstacle contact.
        drain: Battery amount consumed by this step.
        prev_dist: Distance to the active waypoint recorded before this
            step's kinematics were applied.

    Returns:
        The summed step reward.
    """
    INFLUENCE_RADIUS = 10.0

    r = 0.0

    # -- 0. Constant step cost (time pressure) ------------------------------
    r -= 0.01

    # -- 1. Battery efficiency penalty --------------------------------------
    r -= drain * 2.0
    if self.battery <= 0.0:
        r -= 20.0

    # -- 2. Collision penalty -----------------------------------------------
    if collided:
        r -= 5.0

    # -- 3. Waypoint reached: large asymmetric reward -----------------------
    if waypoint_hit:
        r += 100.0

    # -- 4. Potential-based distance shaping --------------------------------
    # F = Φ(s') − Φ(s) = prev_dist − curr_dist, scaled by 1/initial_distance
    # so the magnitude does not depend on how far away the rover spawned.
    #
    # The term is skipped on a waypoint-hit step: prev_dist was measured to
    # the waypoint that has just been reached, whereas active_waypoint now
    # refers to the next one (presumably; confirm against step()). Mixing
    # the two potentials produced a large spurious negative spike. The final
    # waypoint already received no shaping because active_waypoint is falsy.
    wp = self.active_waypoint
    if wp and not waypoint_hit:
        curr_dist = math.hypot(wp[0] - self.px, wp[1] - self.py)
        scale = 1.0 / max(self.initial_distance, 1.0)
        r += (prev_dist - curr_dist) * scale

    # -- 5. Vector-field shaping near obstacles (within 10 m) ---------------
    # NOTE(review): this bonus depends only on position and heading, not on
    # displacement, so a rover parked next to an obstacle with the "right"
    # heading still collects it every step. Consider gating it on progress.
    nearest_obs = self.obstacles.nearest_n(self.px, self.py, 8)
    close_obstacles = [
        (dx, dy, d)
        for dx, dy, d in nearest_obs
        if 1e-6 < d < INFLUENCE_RADIUS
    ]

    if close_obstacles and wp:
        # 5a. Attractive gradient: unit vector toward the goal.
        g_ax = wp[0] - self.px
        g_ay = wp[1] - self.py
        g_a_mag = math.hypot(g_ax, g_ay)
        if g_a_mag > 1e-6:
            g_ax /= g_a_mag
            g_ay /= g_a_mag
        else:
            g_ax, g_ay = 0.0, 0.0

        # 5b. Repulsive gradient: Σ_i (1/d_i² − 1/D²) · normalise(pos − obs_i)
        inv_d_sq = 1.0 / (INFLUENCE_RADIUS * INFLUENCE_RADIUS)
        g_rx, g_ry = 0.0, 0.0
        for dx, dy, d in close_obstacles:
            # (dx, dy) points from the rover to the obstacle; flip it so the
            # contribution pushes away from the obstacle.
            repel_x, repel_y = -dx, -dy
            rep_mag = math.hypot(repel_x, repel_y)
            if rep_mag > 1e-6:
                repel_x /= rep_mag
                repel_y /= rep_mag
            strength = (1.0 / (d * d)) - inv_d_sq
            g_rx += strength * repel_x
            g_ry += strength * repel_y

        # 5c. Blend attractive and repulsive parts into the desired vector.
        alpha = 0.5  # weight of the repulsive component
        blend_x = g_ax + alpha * g_rx
        blend_y = g_ay + alpha * g_ry

        # 5d. Tangent = blended vector rotated 90° counter-clockwise, so the
        #     rover is guided around the obstacle field rather than into it.
        tangent_x = -blend_y
        tangent_y = blend_x
        t_mag = math.hypot(tangent_x, tangent_y)
        if t_mag > 1e-6:
            tangent_x /= t_mag
            tangent_y /= t_mag

        # 5e. Rover heading as a unit vector.
        hx = math.cos(self.heading)
        hy = math.sin(self.heading)

        # 5f. Absolute cosine similarity: either direction along the tangent
        #     (clockwise or counter-clockwise) is acceptable.
        cos_sim = abs(hx * tangent_x + hy * tangent_y)

        # Closer obstacle => stronger signal. Take the minimum explicitly
        # instead of assuming nearest_n() returns its results sorted.
        min_d = min(d for _, _, d in close_obstacles)
        proximity_weight = 1.0 - (min_d / INFLUENCE_RADIUS)
        r += 0.3 * cos_sim * proximity_weight

    # -- 6. Efficiency bonus: all waypoints in < 50% of the step budget -----
    if (self.waypoints_hit == len(self.waypoint_list)
            and self.steps < self.max_steps * 0.5):
        r += 5.0

    return r
|
| 804 |
|
| 805 |
# -------------------------------------------------------------------
|
|
|
|
| 882 |
|
| 883 |
self.steps += 1
|
| 884 |
|
| 885 |
+
# Snapshot distance BEFORE kinematics so potential-based shaping
|
| 886 |
+
# can compute Δd = prev_dist − curr_dist for this step.
|
| 887 |
+
prev_dist = self._prev_distance
|
| 888 |
+
|
| 889 |
self._apply_kinematics(action)
|
| 890 |
drain = self._update_battery(action)
|
| 891 |
collided, nd = self._check_collision()
|
|
|
|
| 897 |
if current_dist < self.min_distance:
|
| 898 |
self.min_distance = current_dist
|
| 899 |
|
| 900 |
+
# Update _prev_distance for the NEXT step's shaping computation
|
| 901 |
+
self._prev_distance = current_dist
|
| 902 |
+
|
| 903 |
all_done = self.waypoints_hit == len(self.waypoint_list)
|
| 904 |
batt_dead = self.battery <= 0.0
|
| 905 |
self.done = all_done or batt_dead
|
|
|
|
| 916 |
else:
|
| 917 |
termination_reason = "unknown"
|
| 918 |
|
| 919 |
+
reward = self._compute_reward(wp_hit, collided, drain, prev_dist)
|
| 920 |
self.total_reward += reward
|
| 921 |
|
| 922 |
obs = self._build_obs()
|
|
|
|
| 1027 |
done=False, truncated=False, waypoints_hit=0,
|
| 1028 |
initial_distance=initial_dist,
|
| 1029 |
min_distance=initial_dist,
|
| 1030 |
+
_prev_distance=initial_dist,
|
| 1031 |
collision_count=0,
|
| 1032 |
)
|
| 1033 |
|
openenv.yaml
CHANGED
|
@@ -281,11 +281,13 @@ reward:
|
|
| 281 |
description: >
|
| 282 |
Step reward signal returned in the 'reward' field of step().
|
| 283 |
The /grader endpoint computes the normalised episode score [0.0, 1.0]
|
| 284 |
-
from the full trajectory.
|
|
|
|
| 285 |
components:
|
| 286 |
waypoint_reached:
|
| 287 |
-
value: +
|
| 288 |
condition: "target_distance < 2.0 meters"
|
|
|
|
| 289 |
step_penalty:
|
| 290 |
value: -0.01
|
| 291 |
condition: "every step"
|
|
@@ -295,6 +297,20 @@ reward:
|
|
| 295 |
battery_depleted:
|
| 296 |
value: -20.0
|
| 297 |
condition: "battery_level == 0.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
efficiency_bonus:
|
| 299 |
value: +5.0
|
| 300 |
condition: "episode completed in < 50% of max_steps"
|
|
|
|
| 281 |
description: >
|
| 282 |
Step reward signal returned in the 'reward' field of step().
|
| 283 |
The /grader endpoint computes the normalised episode score [0.0, 1.0]
|
| 284 |
+
from the full trajectory. Reward shaping uses potential-based and
|
| 285 |
+
vector-field techniques to prevent the "stationary exploit".
|
| 286 |
components:
|
| 287 |
waypoint_reached:
|
| 288 |
+
value: +100.0
|
| 289 |
condition: "target_distance < 2.0 meters"
|
| 290 |
+
note: "Massive asymmetric reward prevents early policy collapse."
|
| 291 |
step_penalty:
|
| 292 |
value: -0.01
|
| 293 |
condition: "every step"
|
|
|
|
| 297 |
battery_depleted:
|
| 298 |
value: -20.0
|
| 299 |
condition: "battery_level == 0.0"
|
| 300 |
+
potential_based_distance_shaping:
|
| 301 |
+
value: "(prev_dist - curr_dist) / initial_distance"
|
| 302 |
+
condition: "every step while waypoint is active"
|
| 303 |
+
note: >
|
| 304 |
+
Φ(s) = −distance. Shaping = Φ(s') − Φ(s) = prev_dist − curr_dist.
|
| 305 |
+
Normalised by initial_distance for spawn-distance independence.
|
| 306 |
+
Standing still yields shaping = 0, so step penalty + drain = net negative.
|
| 307 |
+
vector_field_obstacle_shaping:
|
| 308 |
+
value: "up to +0.3"
|
| 309 |
+
condition: "any obstacle within 10 metres"
|
| 310 |
+
note: >
|
| 311 |
+
Computes attractive (goal) + repulsive (obstacles) gradient blend,
|
| 312 |
+
takes orthogonal tangent, rewards cosine similarity with rover heading.
|
| 313 |
+
Scaled by proximity urgency (closer obstacle = stronger signal).
|
| 314 |
efficiency_bonus:
|
| 315 |
value: +5.0
|
| 316 |
condition: "episode completed in < 50% of max_steps"
|
train.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train.py — GRPO Training Script for Planetary Rover Navigation
|
| 3 |
+
================================================================
|
| 4 |
+
|
| 5 |
+
Uses Unsloth's FastLanguageModel + TRL's GRPOTrainer to fine-tune
|
| 6 |
+
meta-llama/Llama-3.2-1B-Instruct for autonomous rover navigation.
|
| 7 |
+
|
| 8 |
+
Hardware target : NVIDIA RTX 3050 — strict 6 GB VRAM limit
|
| 9 |
+
Quantisation : 4-bit NF4 via Unsloth
|
| 10 |
+
LoRA : rank 16, attention + MLP projections
|
| 11 |
+
GRPO group size : 4 generations per prompt (prevents OOM)
|
| 12 |
+
|
| 13 |
+
Reward functions
|
| 14 |
+
----------------
|
| 15 |
+
1. Format Gatekeeper — validates <action>JSON</action> structure
|
| 16 |
+
2. Environment Reward — POSTs parsed action to local physics server
|
| 17 |
+
|
| 18 |
+
Prerequisites
|
| 19 |
+
-------------
|
| 20 |
+
1. Local server running:
|
| 21 |
+
uvicorn main:app --host 0.0.0.0 --port 7860
|
| 22 |
+
2. Python packages:
|
| 23 |
+
pip install unsloth trl datasets peft accelerate
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import json
|
| 29 |
+
import math
|
| 30 |
+
import os
|
| 31 |
+
import re
|
| 32 |
+
import sys
|
| 33 |
+
import time
|
| 34 |
+
import random
|
| 35 |
+
import logging
|
| 36 |
+
from typing import Any
|
| 37 |
+
|
| 38 |
+
import requests
|
| 39 |
+
import torch
|
| 40 |
+
from datasets import Dataset
|
| 41 |
+
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
# Unsloth + TRL imports (eager, module-level: importing this file loads both; unsloth is imported before trl)
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
from unsloth import FastLanguageModel
|
| 46 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# Configuration
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
# Client-side address of the local physics server. The server may *bind* to
# 0.0.0.0 (all interfaces), but 0.0.0.0 is not a valid destination address on
# every platform (connecting to it fails on Windows), so the client targets
# the loopback interface explicitly.
SERVER_URL = "http://127.0.0.1:7860"
OUTPUT_DIR = "./grpo_rover_checkpoints"
SEED = 42

# VRAM-safe parameters for RTX 3050 (6 GB)
MAX_SEQ_LENGTH = 512  # prompt + completion combined
LORA_RANK = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.0

# Training hyperparameters
NUM_TRAIN_EPISODES = 150  # prompts per task × 3 tasks = total dataset
MAX_PROMPT_LENGTH = 256
MAX_COMPLETION_LENGTH = 256
NUM_GENERATIONS = 4  # GRPO group size — critical for 6 GB
LEARNING_RATE = 1e-6
KL_COEF = 0.04  # β for KL penalty
NUM_TRAIN_EPOCHS = 2
PER_DEVICE_BATCH = 1  # keep at 1 for 6 GB
GRAD_ACCUM_STEPS = 4

# Reward tuning
FORMAT_REWARD_GOOD = 1.0
FORMAT_REWARD_BAD = 0.0
VERBOSITY_THRESHOLD = 80  # tokens — a valid <action>{…}</action> is ~30-40
VERBOSITY_PENALTY_K = 200  # excess tokens before the format reward reaches 0

# Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("train")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# =============================================================================
|
| 86 |
+
# System prompt (compact — must fit within ~90 tokens so user prompt has room)
|
| 87 |
+
# =============================================================================
|
| 88 |
+
|
| 89 |
+
# Prompt text sent as the system message of every training example. The
# garbled character in the steering rule is restored to "≈"; everything else
# is unchanged. The trailing backslash suppresses the final newline.
SYSTEM_PROMPT = """\
You are a planetary rover navigation controller.
Respond ONLY with your action inside <action></action> tags as valid JSON.

Action schema:
{"thrust": float[0,1], "steering": float[-1,1], "brake": 0|1, "vertical_thruster": float[-0.2,0.2]}

Key physics:
- heading_error = atan2(target_dy, target_dx) - rover_heading
- steering ≈ clamp(heading_error * 2.5, -1, 1)
- thrust=1.0 for progress; brake=0 unless overshooting
- If nearest_obstacle < 10m, steer perpendicular to dodge\
"""
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# =============================================================================
|
| 105 |
+
# Compact observation prompt builder
|
| 106 |
+
# =============================================================================
|
| 107 |
+
|
| 108 |
+
def build_compact_prompt(
    task_id: str,
    obs: dict[str, Any],
    step_num: int,
    max_steps: int,
) -> str:
    """
    Build a token-efficient user prompt from an observation dict.

    The heading error and a suggested steering value are pre-computed so the
    model does not have to do trigonometry itself.

    Args:
        task_id: Task label echoed in the prompt header.
        obs: Observation mapping. Reads ``target_relative`` (``x``/``y``),
            ``rover_heading``, ``target_distance``, ``battery_level``,
            ``nearest_obstacle_distance`` and ``terrain_type``.
        step_num: Current step number shown in the header.
        max_steps: Step budget shown in the header.

    Returns:
        The multi-line prompt text.

    Raises:
        KeyError: If ``obs`` lacks one of the fields listed above.
        ValueError: If the heading error is NaN or infinite.
    """
    dx = obs["target_relative"]["x"]
    dy = obs["target_relative"]["y"]

    raw_error = math.atan2(dy, dx) - obs["rover_heading"]
    if not math.isfinite(raw_error):
        # The former while-loop wrap never terminated on ±inf and let NaN
        # through as a bogus steering value of 1.0; fail loudly instead.
        raise ValueError(f"non-finite heading error: {raw_error!r}")

    # Wrap into (-pi, pi] in constant time. math.remainder() yields a value in
    # [-pi, pi]; the lower bound is folded onto +pi to keep the same interval
    # as the original loop-based normalisation.
    full_turn = 2 * math.pi
    heading_error = math.remainder(raw_error, full_turn)
    if heading_error <= -math.pi:
        heading_error += full_turn

    suggested_steering = max(-1.0, min(1.0, heading_error * 2.5))

    return (
        f"TASK: {task_id} STEP: {step_num}/{max_steps}\n"
        f"target_distance={obs['target_distance']:.1f}m "
        f"heading_error={heading_error:.4f}rad\n"
        f"battery={obs['battery_level']:.3f} "
        f"nearest_obstacle={obs['nearest_obstacle_distance']:.1f}m "
        f"terrain={obs['terrain_type']}\n"
        f"suggested_steering={suggested_steering:.4f}\n"
        f"Output your <action> JSON now."
    )
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# =============================================================================
|
| 142 |
+
# Dataset generation — resets episodes and collects initial observations
|
| 143 |
+
# =============================================================================
|
| 144 |
+
|
| 145 |
+
TASK_MAX_STEPS = {"easy": 200, "medium": 300, "hard": 100}
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _check_server() -> None:
    """
    Fail fast if the environment server is unreachable.

    Issues ``GET {SERVER_URL}/tasks`` with a 5-second timeout. On any
    transport or HTTP-status failure the underlying error is logged together
    with a hint on how to start the server, and the process exits with
    status 1.
    """
    try:
        resp = requests.get(f"{SERVER_URL}/tasks", timeout=5)
        resp.raise_for_status()
    except requests.RequestException as exc:
        # Include the actual failure (refused connection, timeout, HTTP 5xx,
        # ...) so the user can tell "server not started" from "server broken".
        log.error(
            "Cannot reach environment server at %s (%s) — "
            "start it with: uvicorn main:app --host 0.0.0.0 --port 7860",
            SERVER_URL,
            exc,
        )
        sys.exit(1)
    log.info("Environment server is live at %s", SERVER_URL)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def generate_training_dataset(n_per_task: int = NUM_TRAIN_EPISODES) -> Dataset:
    """
    Generate a training dataset by resetting episodes across all tasks.

    For every task in ``TASK_MAX_STEPS`` and every seed in
    ``range(n_per_task)`` the server's ``/reset`` endpoint is called and the
    returned initial observation is turned into a chat prompt.

    Each row contains:
        prompt  — chat-formatted messages (system + user)
        task_id — for environment reward replay
        seed    — for environment reward replay

    Args:
        n_per_task: Number of seeds (episodes) requested per task.

    Returns:
        A shuffled ``Dataset`` with one row per successful reset.

    Raises:
        RuntimeError: If no reset succeeded, so there is nothing to train on.
    """
    rows: list[dict[str, Any]] = []

    for task_id, max_steps in TASK_MAX_STEPS.items():
        for seed in range(n_per_task):
            # A single bad reply (transport error, non-JSON body, missing or
            # malformed "obs") skips that seed instead of aborting the run.
            try:
                resp = requests.post(
                    f"{SERVER_URL}/reset",
                    json={"task_id": task_id, "seed": seed},
                    timeout=10,
                )
                resp.raise_for_status()
                obs = resp.json()["obs"]
                user_msg = build_compact_prompt(
                    task_id, obs, step_num=1, max_steps=max_steps
                )
            except (requests.RequestException, KeyError, TypeError, ValueError) as e:
                log.warning("Reset failed (task=%s seed=%d): %s", task_id, seed, e)
                continue

            rows.append({
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
                "task_id": task_id,
                "seed": seed,
            })

    if not rows:
        raise RuntimeError(
            f"No training prompts could be generated from {SERVER_URL}; "
            "every /reset call failed."
        )

    # Shuffle with a dedicated RNG seeded from SEED so the row order is
    # reproducible regardless of the global random state.
    random.Random(SEED).shuffle(rows)
    log.info(
        "Generated %d training prompts (%d per task × %d tasks)",
        len(rows), n_per_task, len(TASK_MAX_STEPS),
    )
    return Dataset.from_list(rows)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# =============================================================================
|
| 207 |
+
# Reward Function 1 — Format Gatekeeper
|
| 208 |
+
# =============================================================================
|
| 209 |
+
|
| 210 |
+
# Regex to extract content between <action> and </action> tags
_ACTION_RE = re.compile(r"<action>\s*(.*?)\s*</action>", re.DOTALL)

# Required fields and their (min, max) bounds
_ACTION_FIELDS = {
    "thrust": (0.0, 1.0),
    "steering": (-1.0, 1.0),
    "brake": (0, 1),
    "vertical_thruster": (-0.2, 0.2),
}


def parse_action_from_completion(text: str) -> dict[str, Any] | None:
    """
    Extract and validate an action JSON from <action>…</action> tags.

    The first ``<action>…</action>`` span is parsed as JSON. Every field in
    ``_ACTION_FIELDS`` must be present, numeric and finite. A value more than
    0.5 outside its bounds is rejected; a milder overshoot is clamped.
    ``brake`` is rounded to an int before the range check.

    Args:
        text: Model completion. A plain string is the normal case. A
            conversational completion (a list of ``{"role", "content"}``
            message dicts, as TRL passes for chat-formatted prompts) is also
            accepted and its message contents are concatenated.

    Returns:
        The validated, clamped action dict, or ``None`` if the completion
        contains no well-formed action.
    """
    if not isinstance(text, str):
        try:
            text = "".join(str(message["content"]) for message in text)
        except (TypeError, KeyError):
            return None

    match = _ACTION_RE.search(text)
    if match is None:
        return None

    try:
        parsed = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None

    if not isinstance(parsed, dict):
        return None

    action: dict[str, Any] = {}
    for field, (lo, hi) in _ACTION_FIELDS.items():
        if field not in parsed:
            return None
        try:
            val = float(parsed[field])
        except (TypeError, ValueError, OverflowError):
            # OverflowError: integer literal too large to fit in a float.
            return None
        # json.loads accepts the non-standard literals NaN / Infinity. NaN
        # used to slip through the range check and clamp to the upper bound;
        # Infinity crashed int(round(...)) for "brake". Reject both up front.
        if not math.isfinite(val):
            return None
        if field == "brake":
            val = int(round(val))
        # Reject wildly out-of-range (mild overshoot is clamped, not rejected)
        if val < lo - 0.5 or val > hi + 0.5:
            return None
        # Clamp to valid bounds
        action[field] = max(lo, min(hi, val))

    return action
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def format_reward_fn(completions: list[Any], **kwargs) -> list[float]:
    """
    Reward Function 1 — The Format Gatekeeper.

    Returns ``FORMAT_REWARD_GOOD`` for a completion that contains a valid
    <action>JSON</action> matching the rover action schema, and
    ``FORMAT_REWARD_BAD`` otherwise.

    A soft verbosity penalty is applied to valid completions: once the
    whitespace-split word count exceeds ``VERBOSITY_THRESHOLD`` the reward
    falls linearly, reaching 0 at
    ``VERBOSITY_THRESHOLD + VERBOSITY_PENALTY_K`` words. The word count is a
    proxy; the real BPE token count will differ.

    Args:
        completions: One entry per generated completion. Each entry is either
            a plain string or, for chat-formatted prompts, a list of message
            dicts whose ``content`` values are concatenated.
        **kwargs: Extra dataset columns passed by the trainer (unused).

    Returns:
        One reward per completion, in input order.
    """
    rewards: list[float] = []

    for completion in completions:
        # The training prompts are chat messages, for which TRL hands reward
        # functions a list of message dicts rather than a string.
        if isinstance(completion, str):
            text = completion
        else:
            text = "".join(str(message["content"]) for message in completion)

        if parse_action_from_completion(text) is None:
            rewards.append(FORMAT_REWARD_BAD)
            continue

        reward = FORMAT_REWARD_GOOD
        excess = len(text.split()) - VERBOSITY_THRESHOLD
        if excess > 0:
            reward *= max(0.0, 1.0 - excess / VERBOSITY_PENALTY_K)

        rewards.append(reward)

    return rewards
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# =============================================================================
|
| 302 |
+
# Reward Function 2 — Environment Reward
|
| 303 |
+
# =============================================================================
|
| 304 |
+
|
| 305 |
+
def environment_reward_fn(completions: list[str], **kwargs) -> list[float]:
    """
    Reward Function 2 — the environment.

    For each completion:
      1. Parse the action with ``parse_action_from_completion``.
      2. POST ``{task_id, seed}`` to ``{SERVER_URL}/reset`` to start an episode.
      3. POST the action to ``{SERVER_URL}/step`` for that episode.
      4. Use the ``reward`` field of the step response as the score.

    Args:
        completions: Generated texts, one per sample.
        **kwargs: Optional ``task_id`` and ``seed`` lists, indexed in step with
            ``completions``. Where an entry is missing, ``"easy"`` and ``0``
            are used.

    Returns:
        One float per completion, in input order. ``0.0`` is returned for a
        completion whose action cannot be parsed, or whose reset/step request
        fails for any reason (the failure is logged as a warning).
    """
    # Extra dataset columns are expected as lists aligned with `completions`.
    # NOTE(review): the original comment states each prompt's metadata is
    # repeated once per generation — confirm against the trainer in use.
    task_ids: list[str] = kwargs.get("task_id", [])
    seeds: list[int] = kwargs.get("seed", [])

    rewards: list[float] = []

    # A single Session lets every reset/step call in this batch reuse the same
    # HTTP connection instead of opening a new one per request.
    with requests.Session() as session:
        for i, text in enumerate(completions):
            # -- Parse action ----------------------------------------------
            action = parse_action_from_completion(text)
            if action is None:
                rewards.append(0.0)
                continue

            # -- Determine episode parameters ------------------------------
            task_id = task_ids[i] if i < len(task_ids) else "easy"
            seed = seeds[i] if i < len(seeds) else 0

            try:
                # Start a new episode for this (task_id, seed) pair.
                reset_resp = session.post(
                    f"{SERVER_URL}/reset",
                    json={"task_id": task_id, "seed": seed},
                    timeout=10,
                )
                reset_resp.raise_for_status()
                episode_id = reset_resp.json()["episode_id"]

                # Apply the generated action as that episode's step.
                step_resp = session.post(
                    f"{SERVER_URL}/step",
                    json=action,
                    params={"episode_id": episode_id},
                    timeout=10,
                )
                step_resp.raise_for_status()
                step_data = step_resp.json()

                rewards.append(float(step_data.get("reward", 0.0)))

            except Exception as e:
                # Deliberately broad: a single failed rollout must not abort
                # the training run, so it degrades to a neutral reward.
                # `%s` for the seed keeps this log call safe for non-int seeds.
                log.warning(
                    "Environment reward failed (task=%s seed=%s): %s",
                    task_id, seed, e,
                )
                rewards.append(0.0)

    return rewards
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# =============================================================================
|
| 367 |
+
# Model loading
|
| 368 |
+
# =============================================================================
|
| 369 |
+
|
| 370 |
+
def load_model():
    """
    Load ``MODEL_NAME`` in 4-bit through Unsloth and attach LoRA adapters.

    Steps:
      1. ``FastLanguageModel.from_pretrained`` with ``load_in_4bit=True`` and
         a context length of ``MAX_SEQ_LENGTH``.
      2. ``FastLanguageModel.get_peft_model`` adds LoRA adapters (rank
         ``LORA_RANK``, alpha ``LORA_ALPHA``, dropout ``LORA_DROPOUT``) to the
         attention and MLP projection modules listed below.
      3. The tokenizer gets a pad token (falling back to EOS) and is switched
         to left padding.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    log.info("Loading %s with 4-bit NF4 quantisation via Unsloth…", MODEL_NAME)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name     = MODEL_NAME,
        max_seq_length = MAX_SEQ_LENGTH,
        dtype          = None,   # None: leave dtype selection to the library
        load_in_4bit   = True,   # 4-bit weights to fit the small VRAM budget
    )

    log.info("Attaching LoRA (rank=%d, alpha=%d) to attention + MLP…", LORA_RANK, LORA_ALPHA)

    model = FastLanguageModel.get_peft_model(
        model,
        r = LORA_RANK,
        target_modules = [
            # Attention projections
            "q_proj", "k_proj", "v_proj", "o_proj",
            # MLP projections
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha   = LORA_ALPHA,
        lora_dropout = LORA_DROPOUT,
        bias         = "none",
        # Unsloth's own gradient-checkpointing mode, chosen to cut VRAM use.
        use_gradient_checkpointing = "unsloth",
        random_state = SEED,
    )

    # Batched generation needs a pad token; reuse EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models generate from the right edge, so pad on the left.
    tokenizer.padding_side = "left"

    vram_gb = torch.cuda.memory_allocated() / 1e9
    log.info("Model loaded. Current VRAM: %.2f GB", vram_gb)

    return model, tokenizer
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
# =============================================================================
|
| 414 |
+
# Training configuration
|
| 415 |
+
# =============================================================================
|
| 416 |
+
|
| 417 |
+
def build_training_config() -> GRPOConfig:
    """
    Assemble the ``GRPOConfig`` for the training run.

    All tunable values come from module-level constants; the original author
    describes the combination as sized for a 6 GB VRAM budget. Mixed precision
    is chosen at call time: bf16 when ``torch.cuda.is_bf16_supported()``
    reports support, fp16 otherwise.

    Returns:
        A populated ``GRPOConfig`` instance.
    """
    # Exactly one of bf16 / fp16 is enabled, depending on hardware support.
    bf16_available = torch.cuda.is_bf16_supported()

    # -- GRPO-specific: group size, length limits, KL penalty coefficient ----
    grpo_settings = {
        "num_generations": NUM_GENERATIONS,
        "max_prompt_length": MAX_PROMPT_LENGTH,
        "max_completion_length": MAX_COMPLETION_LENGTH,
        "beta": KL_COEF,
    }

    # -- Optimiser -----------------------------------------------------------
    optimiser_settings = {
        "learning_rate": LEARNING_RATE,
        "lr_scheduler_type": "cosine",
        "warmup_ratio": 0.05,
        "max_grad_norm": 1.0,
    }

    # -- Batch / accumulation ------------------------------------------------
    batch_settings = {
        "per_device_train_batch_size": PER_DEVICE_BATCH,
        "gradient_accumulation_steps": GRAD_ACCUM_STEPS,
        "num_train_epochs": NUM_TRAIN_EPOCHS,
    }

    # -- Precision / memory --------------------------------------------------
    precision_settings = {
        "bf16": bf16_available,
        "fp16": not bf16_available,
    }

    # -- Logging / saving ----------------------------------------------------
    bookkeeping_settings = {
        "logging_steps": 5,
        "save_steps": 50,
        "save_total_limit": 3,
        "report_to": "none",  # switch to "wandb" to enable experiment tracking
        "seed": SEED,
    }

    # -- Misc ----------------------------------------------------------------
    misc_settings = {
        # Keep extra dataset columns so they reach the reward functions.
        "remove_unused_columns": False,
    }

    return GRPOConfig(
        output_dir=OUTPUT_DIR,
        **grpo_settings,
        **optimiser_settings,
        **batch_settings,
        **precision_settings,
        **bookkeeping_settings,
        **misc_settings,
    )
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# =============================================================================
|
| 456 |
+
# Main entry point
|
| 457 |
+
# =============================================================================
|
| 458 |
+
|
| 459 |
+
def main() -> None:
    """
    Run the full GRPO training pipeline.

    Sequence:
      0. ``_check_server()`` is called before anything else.
      1. Load the model and tokenizer via ``load_model()``.
      2. Build a training dataset of 100 identical chat prompts
         (system + user message).
      3. Build the ``GRPOConfig`` via ``build_training_config()``.
      4. Create a ``GRPOTrainer`` with ``format_reward_fn`` and
         ``environment_reward_fn`` as reward functions.
      5. Train, logging the wall-clock duration.
      6. Save the model and tokenizer under ``OUTPUT_DIR/final_adapter``.
      7. Log peak CUDA memory and warn if it exceeded 6 GB.
    """
    log.info("=" * 60)
    log.info("GRPO Training — Planetary Rover Navigation")
    log.info("Model : %s", MODEL_NAME)
    log.info("VRAM : 6 GB target (4-bit NF4, LoRA r=%d, group=%d)",
             LORA_RANK, NUM_GENERATIONS)
    log.info("=" * 60)

    # -- 0. Check server -----------------------------------------------------
    _check_server()

    # -- 1. Load model + tokenizer -------------------------------------------
    model, tokenizer = load_model()

    # -- 2. Generate training dataset ----------------------------------------
    system_prompt = (
        "You are an autonomous planetary rover. Navigate to the target.\n"
        "Output strictly valid JSON inside <action> tags with 'thrust' "
        "(0.0 to 1.0) and 'steering' (-1.0 to 1.0)."
    )

    # Every row is the same conversation. The dataset has no task_id/seed
    # columns, so environment_reward_fn falls back to its defaults.
    practice_prompts = {
        "prompt": [
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "Telemetry linked. Generate initial navigation action."},
            ]
        ] * 100
    }
    train_dataset = Dataset.from_dict(practice_prompts)

    # -- 3. Build GRPO config ------------------------------------------------
    config = build_training_config()

    # -- 4. Initialise trainer -----------------------------------------------
    log.info("Initialising GRPOTrainer with 2 reward functions…")
    # NOTE(review): `tokenizer=` is kept as in the original; some TRL releases
    # name this parameter `processing_class` — confirm against the installed
    # TRL/Unsloth versions.
    trainer = GRPOTrainer(
        model         = model,
        tokenizer     = tokenizer,
        reward_funcs  = [format_reward_fn, environment_reward_fn],
        args          = config,
        train_dataset = train_dataset,
    )

    # -- 5. Train ------------------------------------------------------------
    log.info("Starting GRPO training…")
    start = time.time()
    trainer.train()
    elapsed = time.time() - start
    log.info("Training complete in %.1f minutes.", elapsed / 60)

    # -- 6. Save final adapter -----------------------------------------------
    final_path = os.path.join(OUTPUT_DIR, "final_adapter")
    model.save_pretrained(final_path)
    tokenizer.save_pretrained(final_path)
    log.info("Final LoRA adapter saved to %s", final_path)

    # -- 7. VRAM summary -----------------------------------------------------
    peak_vram = torch.cuda.max_memory_allocated() / 1e9
    log.info("Peak VRAM usage: %.2f GB (limit: 6.00 GB)", peak_vram)
    if peak_vram > 6.0:
        log.warning("⚠ Peak VRAM exceeded 6 GB! Reduce NUM_GENERATIONS or LORA_RANK.")
    else:
        log.info("✅ VRAM stayed within 6 GB budget.")


# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|