# Fine-Tuning Qwen2.5-3B-Instruct with GRPO for GSM8K Dataset

This notebook demonstrates how to fine-tune the **Qwen2.5-3B-Instruct** model with **GRPO (Group Relative Policy Optimization)** on the **GSM8K** dataset. The goal is to improve the model's ability to solve mathematical reasoning problems by applying reinforcement learning with custom reward functions.
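
For readers new to GRPO, here is a rough, hypothetical sketch of the core idea: several completions are sampled for each prompt, each completion is scored by the reward functions, and each completion's advantage is its reward relative to the rest of its group. This snippet is illustrative only and is not taken from the notebook.

```python
# Illustrative only: group-relative advantages, the idea at the heart of GRPO.
import statistics

def group_relative_advantages(rewards: list[float]) -> list[float]:
    """Normalize per-completion rewards within one prompt's group."""
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards) or 1.0  # guard against a zero-variance group
    return [(r - mean) / std for r in rewards]

# Example: four sampled answers to one GSM8K question, scored by the reward functions.
print(group_relative_advantages([2.5, 0.5, 0.5, 0.0]))  # the correct answer gets a positive advantage
```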

## Overview

The notebook is structured as follows:

1. **Installation**: Installs the necessary libraries, such as `unsloth`, `vllm`, and `trl`, for efficient fine-tuning and inference.
2. **Unsloth Setup**: Configures the environment for faster fine-tuning using Unsloth's `PatchFastRL` and loads the Qwen2.5-3B-Instruct model with LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning.
3. **Data Preparation**: Loads and preprocesses the GSM8K dataset, formatting it with a system prompt and an XML-style reasoning and answer format (see the sketch after this list).
4. **Reward Functions**: Defines custom reward functions to evaluate the model's responses, including:
   - **Correctness Reward**: Checks whether the extracted answer matches the ground truth.
   - **Format Reward**: Ensures the response follows the specified XML format.
   - **Integer Reward**: Verifies that the extracted answer is an integer.
   - **XML Count Reward**: Evaluates the completeness of the XML structure in the response.
5. **GRPO Training**: Configures and runs the GRPO trainer with vLLM for fast inference, using the defined reward functions to optimize the model's performance.
6. **Training Progress**: Monitors training metrics, including rewards, completion length, and KL divergence, to confirm the model is improving over time.
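
As referenced in step 3 above, the sketch below shows one way the GSM8K data can be prepared. The system prompt wording, helper names, and prompt structure here are illustrative assumptions; the notebook's own code may differ in detail.

```python
# Illustrative data preparation (assumed names and prompt text; adapt to the notebook).
from datasets import load_dataset

SYSTEM_PROMPT = """Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>"""

def extract_hash_answer(text: str) -> str | None:
    # GSM8K stores the ground-truth answer after '####'.
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split: str = "train"):
    data = load_dataset("openai/gsm8k", "main")[split]
    return data.map(lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    })

dataset = get_gsm8k_questions()
```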

## Key Features

- **Efficient Fine-Tuning**: Uses Unsloth and LoRA to fine-tune the model with reduced memory usage and faster training (a model-loading sketch follows this list).
- **Custom Reward Functions**: Implements multiple reward functions to guide the model towards generating correct and well-formatted responses.
- **vLLM Integration**: Leverages vLLM for fast inference during training, enabling efficient generation of multiple responses for reward calculation.
- **GSM8K Dataset**: Focuses on improving the model's performance on mathematical reasoning tasks, specifically the GSM8K dataset.
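
The sketch below condenses the model-loading step referenced above. The hyperparameter values are assumptions chosen for illustration, not the notebook's exact settings.

```python
# Illustrative Unsloth + LoRA setup (values are assumptions, not the notebook's exact configuration).
from unsloth import FastLanguageModel, PatchFastRL

PatchFastRL("GRPO", FastLanguageModel)  # patch TRL's GRPO path for Unsloth before loading the model

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    max_seq_length=1024,
    load_in_4bit=True,           # 4-bit quantization to cut memory use
    fast_inference=True,         # enable vLLM-backed generation for GRPO sampling
    gpu_memory_utilization=0.6,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,                        # LoRA rank
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
```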

## Requirements

- Python 3.11
- Libraries: `unsloth`, `vllm`, `trl`, `torch`, `transformers`

## Installation

To set up the environment, run:

```bash
pip install unsloth vllm trl
```

## Usage

- **Load the Model**: The notebook loads the Qwen2.5-3B-Instruct model with LoRA for fine-tuning (see the Unsloth sketch under Key Features).
- **Prepare the Dataset**: The GSM8K dataset is loaded and formatted with a system prompt and an XML-style reasoning and answer format (see the data-preparation sketch under Overview).
- **Define Reward Functions**: Custom reward functions are defined to evaluate the model's responses.
- **Train the Model**: The GRPO trainer is configured and run to fine-tune the model using the defined reward functions (a combined sketch of these two steps follows this list).
- **Monitor Progress**: Training progress is monitored, including rewards, completion length, and KL divergence.
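
Below is a condensed, illustrative sketch of the reward-function and training steps. It assumes the `model`, `tokenizer`, and `dataset` objects from the sketches above; the function bodies, reward values, and hyperparameters are simplified assumptions rather than the notebook's exact code.

```python
# Illustrative reward functions and GRPO trainer setup (simplified; values are assumptions).
import re
from trl import GRPOConfig, GRPOTrainer

def extract_xml_answer(text: str) -> str:
    # Hypothetical helper: pull the text between <answer> tags.
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return match.group(1).strip() if match else ""

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    # Reward completions whose extracted answer matches the ground truth.
    responses = [c[0]["content"] for c in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    return [2.0 if e == a else 0.0 for e, a in zip(extracted, answer)]

def format_reward_func(completions, **kwargs) -> list[float]:
    # Reward completions that follow the <reasoning>/<answer> structure.
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [c[0]["content"] for c in completions]
    return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]

training_args = GRPOConfig(
    use_vllm=True,               # fast generation during training
    learning_rate=5e-6,
    num_generations=6,           # completions sampled per prompt (the "group")
    max_prompt_length=256,
    max_completion_length=256,
    max_steps=250,
    output_dir="outputs",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[format_reward_func, correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()  # rewards, completion length, and KL divergence are logged each step
```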

## Results

The training process is designed to improve the model's ability to generate correct and well-formatted responses to mathematical reasoning problems. The reward functions guide the model towards better performance, and the training progress is logged for analysis.

## Future Work

- **Hyperparameter Tuning**: Experiment with different learning rates, batch sizes, and reward weights to optimize performance.
- **Additional Datasets**: Extend the fine-tuning process to other datasets to improve the model's generalization capabilities.
- **Advanced Reward Functions**: Implement more sophisticated reward functions to further refine the model's responses.

## Acknowledgments

- **Unsloth**: For providing tools that speed up fine-tuning.
- **vLLM**: For enabling fast inference during training.
- **Hugging Face**: For the `trl` library and the GSM8K dataset.
- Special thanks to @sudhir2016 for mentoring me throughout the development of this fine-tuning project.

## License

This project is licensed under the MIT License. See the LICENSE file for details.