# CropRL Training Pipeline
This directory contains the complete pipeline for training a small LLM (e.g., 0.6B parameters) to play the CropRL environment. The pipeline uses Teacher-Student distillation followed by Reinforcement Learning.
## Pipeline Overview
### 1. Data Generation
**Script**: `dataset/generate_sft_data.py`
We use a high-capacity Teacher LLM (such as a 7B, 14B, or 72B model) to play the game and generate trajectories that contain high-quality reasoning and strategies.
```bash
python ../dataset/generate_sft_data.py --model_name Qwen/Qwen3-8B --num_episodes 100 --output_file dataset/sft_data_1.jsonl
```
*Note: This script overwrites the output file to prevent duplicates. To generate large datasets, run it multiple times with different `--seed_base` and `--output_file` arguments.*
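
The multi-run approach described in the note can be scripted. The sketch below only builds the command lines, one per shard, so they can be inspected before anything is launched. It assumes `--seed_base` accepts an integer; the helper name `shard_commands` and the `seed_stride` value are hypothetical choices for illustration, not part of the repository.

```python
import shlex


def shard_commands(num_shards, episodes_per_shard,
                   model_name="Qwen/Qwen3-8B", seed_stride=10_000):
    """Build one generate_sft_data.py command per shard.

    Each shard gets its own --seed_base and --output_file so that no run
    overwrites another run's output. Assumption: --seed_base is an integer.
    """
    commands = []
    for shard in range(num_shards):
        args = [
            "python", "../dataset/generate_sft_data.py",
            "--model_name", model_name,
            "--num_episodes", str(episodes_per_shard),
            # Hypothetical spacing between shards' seed ranges.
            "--seed_base", str(shard * seed_stride),
            "--output_file", f"dataset/sft_data_{shard + 1}.jsonl",
        ]
        commands.append(shlex.join(args))
    return commands


if __name__ == "__main__":
    for cmd in shard_commands(num_shards=3, episodes_per_shard=100):
        print(cmd)
```

If the script derives per-episode seeds from `--seed_base`, keeping `seed_stride` larger than `episodes_per_shard` should avoid overlapping seeds between shards; that behavior is an assumption about the script, not something stated here.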
### 2. Data Filtering
**Script**: `dataset/filter_sft_data.py`
Not every episode generated by the Teacher is successful. We filter the raw JSONL dataset to keep only trajectories where the agent achieved a high `total_return` (change in net worth). The threshold is set with `--min_reward`.
```bash
python ../dataset/filter_sft_data.py --input_file dataset/sft_data_1.jsonl --output_file dataset/sft_filtered.jsonl --min_reward 1000.0
```
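
As a rough illustration of the filtering rule, the sketch below keeps JSONL records whose return meets a threshold. It assumes each line is a JSON object with a top-level `total_return` field and that the threshold is inclusive; the actual record layout and comparison in `filter_sft_data.py` may differ. The `episode` field in the sample data is made up for the example.

```python
import json


def filter_trajectories(lines, min_reward):
    """Keep only records whose `total_return` is at least `min_reward`."""
    kept = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        # Assumption: the episode return lives under a top-level
        # "total_return" key; records without it are dropped.
        if record.get("total_return", float("-inf")) >= min_reward:
            kept.append(record)
    return kept


# Illustrative records, not real CropRL data.
raw = [
    '{"episode": 0, "total_return": 1500.0}',
    '{"episode": 1, "total_return": 250.0}',
    '{"episode": 2, "total_return": 1000.0}',
]
kept = filter_trajectories(raw, min_reward=1000.0)
print([r["episode"] for r in kept])  # [0, 2]
```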
### 3. Supervised Fine-Tuning (SFT)
**Script**: `train_sft.py`
We perform SFT on the small target model (e.g., 0.6B) using the filtered high-quality trajectories. This teaches the model the rules of the game, the formatting requirements, and basic strategies. The LoRA weights are merged into the base model at the end.
```bash
python train_sft.py --data_path dataset/sft_filtered.jsonl
```
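
Merging LoRA weights means folding the low-rank update into the base weight matrix, so the result can be loaded as an ordinary model with no adapter attached. In the standard LoRA formulation this is `W + (alpha / r) * B @ A`. The toy sketch below shows that arithmetic on plain nested lists; the function name and all values are illustrative, and `train_sft.py` may rely on a library routine (PEFT, for example, provides `merge_and_unload()`) rather than anything like this.

```python
def merge_lora(weight, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * (B @ A) for nested-list matrices.

    weight: (out_features, in_features)
    lora_a: (rank, in_features)
    lora_b: (out_features, rank)
    """
    scale = alpha / rank
    rows, cols = len(weight), len(weight[0])
    merged = [row[:] for row in weight]  # copy so the base weight is untouched
    for i in range(rows):
        for j in range(cols):
            delta = sum(lora_b[i][k] * lora_a[k][j] for k in range(rank))
            merged[i][j] += scale * delta
    return merged


# Toy 2x2 weight with a rank-1 adapter (all values are illustrative).
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]
B = [[0.5], [1.0]]
print(merge_lora(W, A, B, alpha=2.0, rank=1))  # [[2.0, 2.0], [2.0, 5.0]]
```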
### 4. Reinforcement Learning (GRPO)
**Script**: `train_grpo.py`
We load the SFT-merged model and train it using GRPO (Group Relative Policy Optimization). The model plays the game in parallel environments, with constrained decoding ensuring that it outputs only valid actions.
```bash
python train_grpo.py --model_name ./sft_merged_model
```
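
GRPO scores each sampled episode relative to the other episodes in its group instead of using a learned value function. A minimal sketch of that group-relative advantage is shown below, assuming the reward is the episode return and that normalization uses the population standard deviation. Implementations differ in these details, so `train_grpo.py` may not match this exactly.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against its own group: (r - mean) / (std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; some implementations use sample std
    return [(r - mu) / (sigma + eps) for r in rewards]


# Illustrative returns for one group of four rollouts.
returns = [1200.0, 800.0, 1000.0, 1000.0]
advantages = group_relative_advantages(returns)
for r, a in zip(returns, advantages):
    print(f"return={r:7.1f}  advantage={a:+.3f}")
```

Episodes that beat their group's average get a positive advantage and are reinforced, while below-average episodes are pushed down. If every episode in a group earns the same return, all advantages are zero and that group contributes no gradient.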
---
## Design Note: LoRA Target Modules
You will notice that the LoRA `target_modules` differ between the SFT and GRPO scripts. **This is intentional.**
* **SFT Target Modules (`all-linear`)**: In SFT, we target every attention and MLP projection: `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]`. Teaching a base model a completely new domain, action format, and reasoning pattern requires significant adapter capacity. Targeting all linear layers gives the student enough trainable parameters to absorb the Teacher's behavior.
* **GRPO Target Modules (`q_proj`, `v_proj`)**: In GRPO, we restrict the targets to just `["q_proj", "v_proj"]`. Reinforcement Learning is often unstable and sensitive to large updates. If we apply LoRA to all layers during RL, the policy can collapse or suffer catastrophic forgetting of the language and formatting constraints learned in SFT. Restricting updates to the query and value projections leaves enough capacity for the model to adjust its action probabilities while protecting the world model learned during SFT.
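
To put a rough number on the capacity difference, the sketch below counts the trainable parameters LoRA adds per transformer block for each target set. A rank-`r` adapter on a linear layer mapping `d_in` to `d_out` adds `r * (d_in + d_out)` parameters. The layer shapes and the rank are hypothetical round numbers chosen for illustration, not the dimensions of the actual student model or the values used by the scripts.

```python
SFT_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj",
               "gate_proj", "up_proj", "down_proj"]
GRPO_TARGETS = ["q_proj", "v_proj"]


def lora_params_per_block(target_modules, shapes, rank):
    """Sum r * (d_in + d_out) over the targeted linear layers of one block."""
    return sum(rank * (shapes[name][0] + shapes[name][1])
               for name in target_modules)


# Hypothetical (d_in, d_out) shapes; real models may differ,
# e.g. smaller k_proj/v_proj outputs under grouped-query attention.
hidden, intermediate = 1024, 3072
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, hidden),
    "v_proj": (hidden, hidden),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}
rank = 16  # hypothetical LoRA rank

sft = lora_params_per_block(SFT_TARGETS, shapes, rank)
grpo = lora_params_per_block(GRPO_TARGETS, shapes, rank)
print(sft, grpo, sft / grpo)  # 327680 65536 5.0
```

Under these assumed shapes, the GRPO adapter has one fifth of the trainable parameters of the SFT adapter, which is the sense in which its updates are more constrained.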
