
CropRL Training Pipeline

This directory contains the complete pipeline for training a small LLM (e.g., 0.6B) to play the CropRL environment using a Teacher-Student distillation approach followed by Reinforcement Learning.

Pipeline Overview

1. Data Generation

Script: dataset/generate_sft_data.py

We use a high-capacity Teacher LLM (like a 7B, 14B, or 72B model) to play the game and generate high-quality reasoning and strategies.

python ../dataset/generate_sft_data.py --model_name Qwen/Qwen3-8B --num_episodes 100 --output_file dataset/sft_data_1.jsonl

Note: This script overwrites the output file to prevent duplicates. To generate large datasets, run it multiple times with different --seed_base and --output_file arguments.
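The multi-run pattern described in the note can be scripted. The following is a minimal sketch, assuming only the flags mentioned above; the helper name build_shard_commands, the shard count, and the choice to space seed bases by the episode count are illustrative assumptions, not part of the repository.

```python
import shlex


def build_shard_commands(num_shards, episodes_per_shard, model_name="Qwen/Qwen3-8B"):
    """Return one generate_sft_data.py argument list per shard.

    Each shard gets its own --seed_base and --output_file, so no run
    overwrites another run's output. Spacing the seed bases by
    episodes_per_shard is an assumption about how the script uses seeds.
    """
    commands = []
    for shard in range(num_shards):
        args = [
            "python", "../dataset/generate_sft_data.py",
            "--model_name", model_name,
            "--num_episodes", str(episodes_per_shard),
            "--seed_base", str(shard * episodes_per_shard),
            "--output_file", f"dataset/sft_data_{shard + 1}.jsonl",
        ]
        commands.append(args)
    return commands


if __name__ == "__main__":
    for args in build_shard_commands(num_shards=3, episodes_per_shard=100):
        # Printed for inspection; subprocess.run(args, check=True) would execute it.
        print(shlex.join(args))
```

Since the filtering command takes a single --input_file, the resulting shards would presumably need to be concatenated or filtered one at a time.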

2. Data Filtering

Script: dataset/filter_sft_data.py

Not all episodes generated by the Teacher will be successful. We filter the raw JSONL dataset to keep only the trajectories whose total_return (change in net worth) clears the --min_reward threshold.

python ../dataset/filter_sft_data.py --input_file dataset/sft_data_1.jsonl --output_file dataset/sft_filtered.jsonl --min_reward 1000.0
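The filtering rule can be illustrated with a small standalone sketch. It assumes each JSONL record stores the episode's return under a top-level "total_return" key and that the threshold is inclusive; neither detail is confirmed by this README, and the "episode" field in the sample records is hypothetical.

```python
import json


def filter_trajectories(lines, min_reward):
    """Keep JSONL records whose total_return is at least min_reward."""
    kept = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        # Assumption: the return lives under "total_return"; records
        # without that key are treated as failures and dropped.
        if record.get("total_return", float("-inf")) >= min_reward:
            kept.append(record)
    return kept


# Hypothetical records standing in for lines of sft_data_1.jsonl.
raw = [
    '{"episode": 0, "total_return": 1500.0}',
    '{"episode": 1, "total_return": 250.0}',
    '{"episode": 2, "total_return": 1000.0}',
]
kept = filter_trajectories(raw, min_reward=1000.0)
print([r["episode"] for r in kept])  # [0, 2]
```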

3. Supervised Fine-Tuning (SFT)

Script: train_sft.py

We perform SFT on the small target model (e.g., 0.6B) using the filtered high-quality trajectories. This teaches the model the rules of the game, the formatting requirements, and basic strategies. The LoRA weights are merged into the base model at the end.

python train_sft.py --data_path ../dataset/sft_filtered.jsonl
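For reference, merging a LoRA adapter has a simple closed form. In the standard LoRA formulation (the symbols below are the conventional ones, not names taken from train_sft.py), each adapted weight matrix is replaced by:

```latex
W_{\text{merged}} = W_0 + \frac{\alpha}{r}\, B A,
\qquad B \in \mathbb{R}^{d_{\text{out}} \times r},\quad
A \in \mathbb{R}^{r \times d_{\text{in}}}
```

Here W_0 is the frozen base weight, r is the LoRA rank, and α is the scaling hyperparameter. After the merge the adapter adds no extra inference cost, and the result can be loaded like an ordinary checkpoint.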

4. Reinforcement Learning (GRPO)

Script: train_grpo.py

We load the SFT-merged model and train it using GRPO (Group Relative Policy Optimization). The model plays the game against itself in parallel environments, using constrained decoding so that it can only output valid actions.

python train_grpo.py --model_name ./sft_merged_model
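The two ideas named above can be sketched in plain Python. This is an illustrative sketch, not the code in train_grpo.py: the function names are hypothetical, the reward values are made up, and normalizing with the population standard deviation is one common choice that the script may or may not share.

```python
import math
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Score each rollout against its own group: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; an assumption
    return [(r - mean) / (std + eps) for r in rewards]


def mask_invalid_actions(logits, valid_ids):
    """Constrained decoding step: set logits of disallowed token ids to -inf."""
    valid = set(valid_ids)
    return [x if i in valid else -math.inf for i, x in enumerate(logits)]


# Hypothetical returns of four rollouts sampled from the same starting state.
advantages = group_relative_advantages([1200.0, 800.0, 1000.0, 1000.0])
print([round(a, 3) for a in advantages])

# Token 1 has the highest raw logit but is not a valid action, so the
# greedy choice falls back to the best valid token.
masked = mask_invalid_actions([0.1, 2.0, -0.5, 1.3], valid_ids=[0, 3])
best = max(range(len(masked)), key=masked.__getitem__)
print(best)  # 3
```

Rollouts that beat their group's average get a positive advantage and the rest a negative one, so no separate value network is needed to provide a baseline.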

Design Note: LoRA Target Modules

You will notice that the LoRA target_modules differ between the SFT and GRPO scripts. This is intentional.

  • SFT Target Modules (all-linear): In SFT, we target ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], that is, every attention and MLP projection. Teaching a base model a completely new domain, action format, and reasoning pattern requires significant parameter capacity. Targeting all linear layers gives the adapter enough capacity to absorb the Teacher's knowledge.
  • GRPO Target Modules (q_proj, v_proj): In GRPO, we restrict the targets to just ["q_proj", "v_proj"]. Reinforcement learning is notoriously unstable and sensitive to large gradient updates. If we apply LoRA to all layers during RL, the policy can easily collapse or suffer from catastrophic forgetting of the language constraints learned in SFT. Constraining the updates to the query and value projections provides just enough capacity for the model to adjust its policy probabilities (RL tuning) while protecting the core world-model learned during SFT.
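To give a rough sense of the capacity gap between the two configurations, the sketch below counts trainable LoRA parameters for each target list. The layer shapes, layer count, and rank are hypothetical values chosen for illustration and are not read from the actual model or scripts; the sketch also assumes both stages use the same rank.

```python
def lora_param_count(shapes, target_modules, rank, num_layers):
    """Trainable LoRA parameters: rank * (d_in + d_out) per targeted matrix."""
    per_layer = sum(
        rank * (d_in + d_out)
        for name, (d_in, d_out) in shapes.items()
        if name in target_modules
    )
    return per_layer * num_layers


# Hypothetical (d_in, d_out) shapes; NOT taken from any real checkpoint.
HIDDEN, KV, FFN = 1024, 512, 3072
SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV),
    "v_proj": (HIDDEN, KV),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, FFN),
    "up_proj": (HIDDEN, FFN),
    "down_proj": (FFN, HIDDEN),
}
SFT_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj",
               "gate_proj", "up_proj", "down_proj"]
GRPO_TARGETS = ["q_proj", "v_proj"]

# rank=16 and num_layers=24 are illustrative assumptions.
sft = lora_param_count(SHAPES, SFT_TARGETS, rank=16, num_layers=24)
grpo = lora_param_count(SHAPES, GRPO_TARGETS, rank=16, num_layers=24)
print(sft, grpo, round(sft / grpo, 1))  # 7471104 1376256 5.4
```

Under these assumed shapes, the SFT adapter has roughly five times as many trainable parameters as the GRPO adapter, which matches the intent of keeping RL updates small.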